1use futures::{ready, sink::Sink};
35use pin_project_lite::pin_project;
36use std::{
37 future::Future,
38 pin::Pin,
39 task::{Context, Poll},
40};
41
42pub(crate) fn make_sink<S, F, T, A, E>(init: S, f: F) -> SinkImpl<S, F, T, A, E>
49where
50 F: FnMut(S, Action<A>) -> T,
51 T: Future<Output = Result<S, E>>,
52{
53 SinkImpl {
54 lambda: f,
55 future: None,
56 param: Some(init),
57 state: State::Empty,
58 _mark: std::marker::PhantomData,
59 }
60}
61
/// The operation the sink asks the user closure to perform.
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) enum Action<A> {
    /// Deliver this item; issued from `start_send`.
    Send(A),
    /// Flush anything buffered so far; issued from `poll_flush`.
    Flush,
    /// Shut the underlying resource down; issued from `poll_close`.
    Close,
}
79
/// Lifecycle of a quick sink.
#[derive(Eq, PartialEq, Debug)]
enum State {
    /// Idle: no future is in flight and the user state value is held by the sink.
    Empty,
    /// A future produced for `Action::Send` is in flight.
    Sending,
    /// A future produced for `Action::Flush` is in flight.
    Flushing,
    /// A future produced for `Action::Close` is in flight.
    Closing,
    /// The close future completed successfully.
    Closed,
    /// A future produced by the closure returned an error; further use reports `Closed`.
    Failed,
}
96
97#[derive(Debug, thiserror::Error)]
99pub(crate) enum Error<E> {
100 #[error("Error while sending over the sink, {0}")]
101 Send(E),
102 #[error("The Sink has closed")]
103 Closed,
104}
105
pin_project! {
    /// `Sink` returned by `make_sink`.
    ///
    /// While idle the user state lives in `param`; when an action starts it is moved
    /// into `lambda`, and the resulting future is stored (pinned) in `future` until it
    /// resolves and hands the state back.
    #[derive(Debug)]
    pub(crate) struct SinkImpl<S, F, T, A, E> {
        // User closure, invoked once per send, flush or close action.
        lambda: F,
        // Future of the action currently in flight, if any; polled in place, hence pinned.
        #[pin] future: Option<T>,
        // User state while no action is in flight.
        param: Option<S>,
        // Where the sink is in its lifecycle.
        state: State,
        // Ties the item type `A` and error type `E` to the struct without storing them.
        _mark: std::marker::PhantomData<(A, E)>
    }
}
117
118impl<S, F, T, A, E> Sink<A> for SinkImpl<S, F, T, A, E>
119where
120 F: FnMut(S, Action<A>) -> T,
121 T: Future<Output = Result<S, E>>,
122{
123 type Error = Error<E>;
124
125 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
126 let mut this = self.project();
127 match this.state {
128 State::Sending | State::Flushing => {
129 match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
130 Ok(p) => {
131 this.future.set(None);
132 *this.param = Some(p);
133 *this.state = State::Empty;
134 Poll::Ready(Ok(()))
135 }
136 Err(e) => {
137 this.future.set(None);
138 *this.state = State::Failed;
139 Poll::Ready(Err(Error::Send(e)))
140 }
141 }
142 }
143 State::Closing => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
144 Ok(_) => {
145 this.future.set(None);
146 *this.state = State::Closed;
147 Poll::Ready(Err(Error::Closed))
148 }
149 Err(e) => {
150 this.future.set(None);
151 *this.state = State::Failed;
152 Poll::Ready(Err(Error::Send(e)))
153 }
154 },
155 State::Empty => {
156 assert!(this.param.is_some());
157 Poll::Ready(Ok(()))
158 }
159 State::Closed | State::Failed => Poll::Ready(Err(Error::Closed)),
160 }
161 }
162
163 fn start_send(self: Pin<&mut Self>, item: A) -> Result<(), Self::Error> {
164 assert_eq!(State::Empty, self.state);
165 let mut this = self.project();
166 let param = this.param.take().unwrap();
167 let future = (this.lambda)(param, Action::Send(item));
168 this.future.set(Some(future));
169 *this.state = State::Sending;
170 Ok(())
171 }
172
173 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
174 loop {
175 let mut this = self.as_mut().project();
176 match this.state {
177 State::Empty => {
178 if let Some(p) = this.param.take() {
179 let future = (this.lambda)(p, Action::Flush);
180 this.future.set(Some(future));
181 *this.state = State::Flushing
182 } else {
183 return Poll::Ready(Ok(()));
184 }
185 }
186 State::Sending => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
187 {
188 Ok(p) => {
189 this.future.set(None);
190 *this.param = Some(p);
191 *this.state = State::Empty
192 }
193 Err(e) => {
194 this.future.set(None);
195 *this.state = State::Failed;
196 return Poll::Ready(Err(Error::Send(e)));
197 }
198 },
199 State::Flushing => {
200 match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
201 Ok(p) => {
202 this.future.set(None);
203 *this.param = Some(p);
204 *this.state = State::Empty;
205 return Poll::Ready(Ok(()));
206 }
207 Err(e) => {
208 this.future.set(None);
209 *this.state = State::Failed;
210 return Poll::Ready(Err(Error::Send(e)));
211 }
212 }
213 }
214 State::Closing => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
215 {
216 Ok(_) => {
217 this.future.set(None);
218 *this.state = State::Closed;
219 return Poll::Ready(Ok(()));
220 }
221 Err(e) => {
222 this.future.set(None);
223 *this.state = State::Failed;
224 return Poll::Ready(Err(Error::Send(e)));
225 }
226 },
227 State::Closed | State::Failed => return Poll::Ready(Err(Error::Closed)),
228 }
229 }
230 }
231
232 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
233 loop {
234 let mut this = self.as_mut().project();
235 match this.state {
236 State::Empty => {
237 if let Some(p) = this.param.take() {
238 let future = (this.lambda)(p, Action::Close);
239 this.future.set(Some(future));
240 *this.state = State::Closing;
241 } else {
242 return Poll::Ready(Ok(()));
243 }
244 }
245 State::Sending => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
246 {
247 Ok(p) => {
248 this.future.set(None);
249 *this.param = Some(p);
250 *this.state = State::Empty
251 }
252 Err(e) => {
253 this.future.set(None);
254 *this.state = State::Failed;
255 return Poll::Ready(Err(Error::Send(e)));
256 }
257 },
258 State::Flushing => {
259 match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
260 Ok(p) => {
261 this.future.set(None);
262 *this.param = Some(p);
263 *this.state = State::Empty
264 }
265 Err(e) => {
266 this.future.set(None);
267 *this.state = State::Failed;
268 return Poll::Ready(Err(Error::Send(e)));
269 }
270 }
271 }
272 State::Closing => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
273 {
274 Ok(_) => {
275 this.future.set(None);
276 *this.state = State::Closed;
277 return Poll::Ready(Ok(()));
278 }
279 Err(e) => {
280 this.future.set(None);
281 *this.state = State::Failed;
282 return Poll::Ready(Err(Error::Send(e)));
283 }
284 },
285 State::Closed => return Poll::Ready(Ok(())),
286 State::Failed => return Poll::Ready(Err(Error::Closed)),
287 }
288 }
289 }
290}
291
#[cfg(test)]
mod tests {
    use crate::quicksink::{make_sink, Action};
    use async_std::{io, task};
    use futures::{channel::mpsc, prelude::*, stream};

    /// Forwarding a stream into a stdout-backed sink completes without error.
    #[test]
    fn smoke_test() {
        task::block_on(async {
            let sink = make_sink(io::stdout(), |mut stdout, action| async move {
                match action {
                    Action::Send(x) => stdout.write_all(x).await?,
                    Action::Flush => stdout.flush().await?,
                    Action::Close => stdout.close().await?,
                }
                Ok::<_, io::Error>(stdout)
            });

            let values = vec![Ok(&b"hello\n"[..]), Ok(&b"world\n"[..])];
            assert!(stream::iter(values).forward(sink).await.is_ok())
        })
    }

    /// The closure sees a `Send` + `Flush` pair per item sent, followed by a final `Close`.
    #[test]
    fn replay() {
        task::block_on(async {
            let (tx, rx) = mpsc::channel(5);

            let sink = make_sink(tx, |mut tx, action| async move {
                tx.send(action.clone()).await?;
                if action == Action::Close {
                    tx.close().await?
                }
                Ok::<_, mpsc::SendError>(tx)
            });

            futures::pin_mut!(sink);

            let expected = [
                Action::Send("hello\n"),
                Action::Flush,
                Action::Send("world\n"),
                Action::Flush,
                Action::Close,
            ];

            for &item in &["hello\n", "world\n"] {
                sink.send(item).await.unwrap()
            }

            sink.close().await.unwrap();

            let actual = rx.collect::<Vec<_>>().await;

            assert_eq!(&expected[..], &actual[..])
        });
    }

    /// A failing closure surfaces its error once; later sends report `Error::Closed`.
    #[test]
    fn error_does_not_panic() {
        task::block_on(async {
            let sink = make_sink(io::stdout(), |_stdout, _action| async move {
                Err(io::Error::new(io::ErrorKind::Other, "oh no"))
            });

            futures::pin_mut!(sink);

            let result = sink.send("hello").await;
            match result {
                Err(crate::quicksink::Error::Send(e)) => {
                    assert_eq!(e.kind(), io::ErrorKind::Other);
                    assert_eq!(e.to_string(), "oh no")
                }
                _ => panic!("unexpected result: {:?}", result),
            };

            let result = sink.send("hello").await;
            match result {
                Err(crate::quicksink::Error::Closed) => {}
                _ => panic!("unexpected result: {:?}", result),
            };
        })
    }

    /// Closing twice is harmless, and sending after close reports `Error::Closed`
    /// without invoking the closure again.
    #[test]
    fn use_after_close() {
        task::block_on(async {
            let (tx, rx) = mpsc::channel(5);

            let sink = make_sink(tx, |mut tx, action| async move {
                tx.send(action.clone()).await?;
                if action == Action::Close {
                    tx.close().await?
                }
                Ok::<_, mpsc::SendError>(tx)
            });

            futures::pin_mut!(sink);

            sink.send("hello").await.unwrap();
            sink.close().await.unwrap();
            sink.close().await.unwrap();

            let result = sink.send("late").await;
            match result {
                Err(crate::quicksink::Error::Closed) => {}
                _ => panic!("unexpected result: {:?}", result),
            };

            let actual = rx.collect::<Vec<_>>().await;
            let expected = [Action::Send("hello"), Action::Flush, Action::Close];
            assert_eq!(&expected[..], &actual[..])
        })
    }
}