libp2p_websocket/
quicksink.rs

1// Copyright (c) 2019-2020 Parity Technologies (UK) Ltd.
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8//
9// Forked into rust-libp2p and further distributed under the MIT license.
10
// Creates a [`Sink`] implementation from an initial value and a closure
// returning a [`Future`].
//
// This is very similar to how `futures::stream::unfold` creates a `Stream`
// implementation from a seed value and a future-returning closure.
//
// # Examples
//
// ```no_run
// use async_std::io;
// use futures::prelude::*;
// use crate::quicksink::Action;
//
// crate::quicksink::make_sink(io::stdout(), |mut stdout, action| async move {
//     match action {
//         Action::Send(x) => stdout.write_all(x).await?,
//         Action::Flush => stdout.flush().await?,
//         Action::Close => stdout.close().await?,
//     }
//     Ok::<_, io::Error>(stdout)
// });
// ```
33
34use futures::{ready, sink::Sink};
35use pin_project_lite::pin_project;
36use std::{
37    future::Future,
38    pin::Pin,
39    task::{Context, Poll},
40};
41
42/// Returns a `Sink` impl based on the initial value and the given closure.
43///
44/// The closure will be applied to the initial value and an [`Action`] that
45/// informs it about the action it should perform. The returned [`Future`]
46/// will resolve to another value and the process starts over using this
47/// output.
48pub(crate) fn make_sink<S, F, T, A, E>(init: S, f: F) -> SinkImpl<S, F, T, A, E>
49where
50    F: FnMut(S, Action<A>) -> T,
51    T: Future<Output = Result<S, E>>,
52{
53    SinkImpl {
54        lambda: f,
55        future: None,
56        param: Some(init),
57        state: State::Empty,
58        _mark: std::marker::PhantomData,
59    }
60}
61
/// What the closure behind the sink is asked to do next.
///
/// Presumably the closure wraps some I/O resource. Each variant mirrors one
/// [`Sink`] method, so the closure knows which operation to perform on it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Action<A> {
    /// Emit the carried value; issued by [`Sink::start_send`].
    Send(A),
    /// Flush the resource; issued by [`Sink::poll_flush`].
    Flush,
    /// Shut the resource down; issued by [`Sink::poll_close`].
    Close,
}
79
/// Lifecycle of the sink.
///
/// `Sending`, `Flushing` and `Closing` mean a future returned by the closure
/// is in flight. `Closed` and `Failed` are terminal.
#[derive(Debug, PartialEq, Eq)]
enum State {
    /// Idle: the current value is stored and nothing is in flight.
    Empty,
    /// Waiting for a `Send` action to finish.
    Sending,
    /// Waiting for a `Flush` action to finish.
    Flushing,
    /// Waiting for a `Close` action to finish.
    Closing,
    /// The close action finished successfully (terminal).
    Closed,
    /// A future returned by the closure resolved to an error (terminal).
    Failed,
}
96
97/// Errors the `Sink` may return.
98#[derive(Debug, thiserror::Error)]
99pub(crate) enum Error<E> {
100    #[error("Error while sending over the sink, {0}")]
101    Send(E),
102    #[error("The Sink has closed")]
103    Closed,
104}
105
pin_project! {
    /// `SinkImpl` implements the `Sink` trait.
    ///
    /// Values are created by [`make_sink`]. The struct holds the user closure,
    /// at most one future returned by it, and the value threaded from one
    /// closure invocation to the next.
    #[derive(Debug)]
    pub(crate) struct SinkImpl<S, F, T, A, E> {
        // Closure invoked with the stored value and an `Action`; returns the next future.
        lambda: F,
        // Future of the operation currently in flight, if any. Pinned (`#[pin]`)
        // so it can be polled in place without requiring `T: Unpin`.
        #[pin] future: Option<T>,
        // Value handed to the next closure call. `None` while a future is in
        // flight and after the sink has closed or failed.
        param: Option<S>,
        // Current lifecycle state of the sink.
        state: State,
        // Marks the otherwise-unused `A` (item) and `E` (error) type parameters as used.
        _mark: std::marker::PhantomData<(A, E)>
    }
}
117
118impl<S, F, T, A, E> Sink<A> for SinkImpl<S, F, T, A, E>
119where
120    F: FnMut(S, Action<A>) -> T,
121    T: Future<Output = Result<S, E>>,
122{
123    type Error = Error<E>;
124
125    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
126        let mut this = self.project();
127        match this.state {
128            State::Sending | State::Flushing => {
129                match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
130                    Ok(p) => {
131                        this.future.set(None);
132                        *this.param = Some(p);
133                        *this.state = State::Empty;
134                        Poll::Ready(Ok(()))
135                    }
136                    Err(e) => {
137                        this.future.set(None);
138                        *this.state = State::Failed;
139                        Poll::Ready(Err(Error::Send(e)))
140                    }
141                }
142            }
143            State::Closing => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
144                Ok(_) => {
145                    this.future.set(None);
146                    *this.state = State::Closed;
147                    Poll::Ready(Err(Error::Closed))
148                }
149                Err(e) => {
150                    this.future.set(None);
151                    *this.state = State::Failed;
152                    Poll::Ready(Err(Error::Send(e)))
153                }
154            },
155            State::Empty => {
156                assert!(this.param.is_some());
157                Poll::Ready(Ok(()))
158            }
159            State::Closed | State::Failed => Poll::Ready(Err(Error::Closed)),
160        }
161    }
162
163    fn start_send(self: Pin<&mut Self>, item: A) -> Result<(), Self::Error> {
164        assert_eq!(State::Empty, self.state);
165        let mut this = self.project();
166        let param = this.param.take().unwrap();
167        let future = (this.lambda)(param, Action::Send(item));
168        this.future.set(Some(future));
169        *this.state = State::Sending;
170        Ok(())
171    }
172
173    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
174        loop {
175            let mut this = self.as_mut().project();
176            match this.state {
177                State::Empty => {
178                    if let Some(p) = this.param.take() {
179                        let future = (this.lambda)(p, Action::Flush);
180                        this.future.set(Some(future));
181                        *this.state = State::Flushing
182                    } else {
183                        return Poll::Ready(Ok(()));
184                    }
185                }
186                State::Sending => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
187                {
188                    Ok(p) => {
189                        this.future.set(None);
190                        *this.param = Some(p);
191                        *this.state = State::Empty
192                    }
193                    Err(e) => {
194                        this.future.set(None);
195                        *this.state = State::Failed;
196                        return Poll::Ready(Err(Error::Send(e)));
197                    }
198                },
199                State::Flushing => {
200                    match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
201                        Ok(p) => {
202                            this.future.set(None);
203                            *this.param = Some(p);
204                            *this.state = State::Empty;
205                            return Poll::Ready(Ok(()));
206                        }
207                        Err(e) => {
208                            this.future.set(None);
209                            *this.state = State::Failed;
210                            return Poll::Ready(Err(Error::Send(e)));
211                        }
212                    }
213                }
214                State::Closing => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
215                {
216                    Ok(_) => {
217                        this.future.set(None);
218                        *this.state = State::Closed;
219                        return Poll::Ready(Ok(()));
220                    }
221                    Err(e) => {
222                        this.future.set(None);
223                        *this.state = State::Failed;
224                        return Poll::Ready(Err(Error::Send(e)));
225                    }
226                },
227                State::Closed | State::Failed => return Poll::Ready(Err(Error::Closed)),
228            }
229        }
230    }
231
232    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
233        loop {
234            let mut this = self.as_mut().project();
235            match this.state {
236                State::Empty => {
237                    if let Some(p) = this.param.take() {
238                        let future = (this.lambda)(p, Action::Close);
239                        this.future.set(Some(future));
240                        *this.state = State::Closing;
241                    } else {
242                        return Poll::Ready(Ok(()));
243                    }
244                }
245                State::Sending => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
246                {
247                    Ok(p) => {
248                        this.future.set(None);
249                        *this.param = Some(p);
250                        *this.state = State::Empty
251                    }
252                    Err(e) => {
253                        this.future.set(None);
254                        *this.state = State::Failed;
255                        return Poll::Ready(Err(Error::Send(e)));
256                    }
257                },
258                State::Flushing => {
259                    match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
260                        Ok(p) => {
261                            this.future.set(None);
262                            *this.param = Some(p);
263                            *this.state = State::Empty
264                        }
265                        Err(e) => {
266                            this.future.set(None);
267                            *this.state = State::Failed;
268                            return Poll::Ready(Err(Error::Send(e)));
269                        }
270                    }
271                }
272                State::Closing => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
273                {
274                    Ok(_) => {
275                        this.future.set(None);
276                        *this.state = State::Closed;
277                        return Poll::Ready(Ok(()));
278                    }
279                    Err(e) => {
280                        this.future.set(None);
281                        *this.state = State::Failed;
282                        return Poll::Ready(Err(Error::Send(e)));
283                    }
284                },
285                State::Closed => return Poll::Ready(Ok(())),
286                State::Failed => return Poll::Ready(Err(Error::Closed)),
287            }
288        }
289    }
290}
291
#[cfg(test)]
mod tests {
    use crate::quicksink::{make_sink, Action, Error};
    use async_std::{io, task};
    use futures::{channel::mpsc, prelude::*, stream};

    /// Forwarding two chunks to stdout through the sink succeeds.
    #[test]
    fn smoke_test() {
        task::block_on(async {
            let sink = make_sink(io::stdout(), |mut out, action| async move {
                match action {
                    Action::Send(bytes) => out.write_all(bytes).await?,
                    Action::Flush => out.flush().await?,
                    Action::Close => out.close().await?,
                }
                Ok::<_, io::Error>(out)
            });

            let chunks = stream::iter(vec![Ok(&b"hello\n"[..]), Ok(&b"world\n"[..])]);
            let outcome = chunks.forward(sink).await;
            assert!(outcome.is_ok())
        })
    }

    /// Every action reaches the closure in the order the sink issued it.
    #[test]
    fn replay() {
        task::block_on(async {
            let (sender, receiver) = mpsc::channel(5);

            let sink = make_sink(sender, |mut sender, action| async move {
                sender.send(action.clone()).await?;
                if action == Action::Close {
                    sender.close().await?
                }
                Ok::<_, mpsc::SendError>(sender)
            });
            futures::pin_mut!(sink);

            for line in ["hello\n", "world\n"] {
                sink.send(line).await.unwrap()
            }
            sink.close().await.unwrap();

            let recorded = receiver.collect::<Vec<_>>().await;
            let expected = [
                Action::Send("hello\n"),
                Action::Flush,
                Action::Send("world\n"),
                Action::Flush,
                Action::Close,
            ];
            assert_eq!(&expected[..], &recorded[..])
        });
    }

    /// A failing closure surfaces its error once, then the sink reports `Closed`.
    #[test]
    fn error_does_not_panic() {
        task::block_on(async {
            let sink = make_sink(io::stdout(), |_stdout, _action| async move {
                Err(io::Error::new(io::ErrorKind::Other, "oh no"))
            });
            futures::pin_mut!(sink);

            // The first send propagates the closure's error unchanged.
            match sink.send("hello").await {
                Err(Error::Send(e)) => {
                    assert_eq!(e.kind(), io::ErrorKind::Other);
                    assert_eq!(e.to_string(), "oh no")
                }
                other => panic!("unexpected result: {:?}", other),
            }

            // Call send again, expect not to panic.
            let second = sink.send("hello").await;
            assert!(
                matches!(second, Err(Error::Closed)),
                "unexpected result: {:?}",
                second
            );
        })
    }
}