// File: litep2p/yamux/frame/io.rs

// Copyright (c) 2019 Parity Technologies (UK) Ltd.
//
// Licensed under the Apache License, Version 2.0 or MIT license, at your option.
//
// A copy of the Apache License, Version 2.0 is included in the software as
// LICENSE-APACHE and a copy of the MIT license is included in the software
// as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0
// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
// at https://opensource.org/licenses/MIT.

use super::{
    header::{self, HeaderDecodeError},
    Frame,
};
use crate::yamux::connection::Id;
use futures::{prelude::*, ready};
use std::{
    fmt, io,
    pin::Pin,
    task::{Context, Poll},
};

/// Logging target for the file.
const LOG_TARGET: &str = "litep2p::yamux";

26/// A [`Stream`] and writer of [`Frame`] values.
27#[derive(Debug)]
28pub(crate) struct Io<T> {
29    id: Id,
30    io: T,
31    read_state: ReadState,
32    write_state: WriteState,
33    max_body_len: usize,
34}
35
36impl<T: AsyncRead + AsyncWrite + Unpin> Io<T> {
37    pub(crate) fn new(id: Id, io: T, max_frame_body_len: usize) -> Self {
38        Io {
39            id,
40            io,
41            read_state: ReadState::Init,
42            write_state: WriteState::Init,
43            max_body_len: max_frame_body_len,
44        }
45    }
46}
47
48/// The stages of writing a new `Frame`.
49enum WriteState {
50    Init,
51    Header {
52        header: [u8; header::HEADER_SIZE],
53        buffer: Vec<u8>,
54        offset: usize,
55    },
56    Body {
57        buffer: Vec<u8>,
58        offset: usize,
59    },
60}
61
62impl fmt::Debug for WriteState {
63    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
64        match self {
65            WriteState::Init => f.write_str("(WriteState::Init)"),
66            WriteState::Header { offset, .. } => {
67                write!(f, "(WriteState::Header (offset {}))", offset)
68            }
69            WriteState::Body { offset, buffer } => {
70                write!(
71                    f,
72                    "(WriteState::Body (offset {}) (buffer-len {}))",
73                    offset,
74                    buffer.len()
75                )
76            }
77        }
78    }
79}
80
81impl<T: AsyncRead + AsyncWrite + Unpin> Sink<Frame<()>> for Io<T> {
82    type Error = io::Error;
83
84    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
85        let this = Pin::into_inner(self);
86        loop {
87            tracing::trace!(target: LOG_TARGET, "{}: write: {:?}", this.id, this.write_state);
88            match &mut this.write_state {
89                WriteState::Init => return Poll::Ready(Ok(())),
90                WriteState::Header {
91                    header,
92                    buffer,
93                    ref mut offset,
94                } => match Pin::new(&mut this.io).poll_write(cx, &header[*offset..]) {
95                    Poll::Pending => return Poll::Pending,
96                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
97                    Poll::Ready(Ok(n)) => {
98                        if n == 0 {
99                            return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
100                        }
101                        *offset += n;
102                        if *offset == header.len() {
103                            if !buffer.is_empty() {
104                                let buffer = std::mem::take(buffer);
105                                this.write_state = WriteState::Body { buffer, offset: 0 };
106                            } else {
107                                this.write_state = WriteState::Init;
108                            }
109                        }
110                    }
111                },
112                WriteState::Body {
113                    buffer,
114                    ref mut offset,
115                } => match Pin::new(&mut this.io).poll_write(cx, &buffer[*offset..]) {
116                    Poll::Pending => return Poll::Pending,
117                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
118                    Poll::Ready(Ok(n)) => {
119                        if n == 0 {
120                            return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
121                        }
122                        *offset += n;
123                        if *offset == buffer.len() {
124                            this.write_state = WriteState::Init;
125                        }
126                    }
127                },
128            }
129        }
130    }
131
132    fn start_send(self: Pin<&mut Self>, f: Frame<()>) -> Result<(), Self::Error> {
133        let header = header::encode(&f.header);
134        let buffer = f.body;
135        self.get_mut().write_state = WriteState::Header {
136            header,
137            buffer,
138            offset: 0,
139        };
140        Ok(())
141    }
142
143    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
144        let this = Pin::into_inner(self);
145        ready!(this.poll_ready_unpin(cx))?;
146        Pin::new(&mut this.io).poll_flush(cx)
147    }
148
149    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
150        let this = Pin::into_inner(self);
151        ready!(this.poll_ready_unpin(cx))?;
152        Pin::new(&mut this.io).poll_close(cx)
153    }
154}
155
156/// The stages of reading a new `Frame`.
157enum ReadState {
158    /// Initial reading state.
159    Init,
160    /// Reading the frame header.
161    Header {
162        offset: usize,
163        buffer: [u8; header::HEADER_SIZE],
164    },
165    /// Reading the frame body.
166    Body {
167        header: header::Header<()>,
168        offset: usize,
169        buffer: Vec<u8>,
170    },
171}
172
173impl<T: AsyncRead + AsyncWrite + Unpin> Stream for Io<T> {
174    type Item = Result<Frame<()>, FrameDecodeError>;
175
176    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
177        let this = &mut *self;
178        loop {
179            tracing::trace!(target: LOG_TARGET, "{}: read: {:?}", this.id, this.read_state);
180            match this.read_state {
181                ReadState::Init => {
182                    this.read_state = ReadState::Header {
183                        offset: 0,
184                        buffer: [0; header::HEADER_SIZE],
185                    };
186                }
187                ReadState::Header {
188                    ref mut offset,
189                    ref mut buffer,
190                } => {
191                    if *offset == header::HEADER_SIZE {
192                        let header = match header::decode(buffer) {
193                            Ok(hd) => hd,
194                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
195                        };
196
197                        tracing::trace!(target: LOG_TARGET, "{}: read: {}", this.id, header);
198
199                        if header.tag() != header::Tag::Data {
200                            this.read_state = ReadState::Init;
201                            return Poll::Ready(Some(Ok(Frame::new(header))));
202                        }
203
204                        let body_len = header.len().val() as usize;
205
206                        if body_len > this.max_body_len {
207                            return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge(
208                                body_len,
209                            ))));
210                        }
211
212                        this.read_state = ReadState::Body {
213                            header,
214                            offset: 0,
215                            buffer: vec![0; body_len],
216                        };
217
218                        continue;
219                    }
220
221                    let buf = &mut buffer[*offset..header::HEADER_SIZE];
222                    match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
223                        0 => {
224                            if *offset == 0 {
225                                return Poll::Ready(None);
226                            }
227                            let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
228                            return Poll::Ready(Some(Err(e)));
229                        }
230                        n => *offset += n,
231                    }
232                }
233                ReadState::Body {
234                    ref header,
235                    ref mut offset,
236                    ref mut buffer,
237                } => {
238                    let body_len = header.len().val() as usize;
239
240                    if *offset == body_len {
241                        let h = header.clone();
242                        let v = std::mem::take(buffer);
243                        this.read_state = ReadState::Init;
244                        return Poll::Ready(Some(Ok(Frame { header: h, body: v })));
245                    }
246
247                    let buf = &mut buffer[*offset..body_len];
248                    match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
249                        0 => {
250                            let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
251                            return Poll::Ready(Some(Err(e)));
252                        }
253                        n => *offset += n,
254                    }
255                }
256            }
257        }
258    }
259}
260
261impl fmt::Debug for ReadState {
262    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
263        match self {
264            ReadState::Init => f.write_str("(ReadState::Init)"),
265            ReadState::Header { offset, .. } => {
266                write!(f, "(ReadState::Header (offset {}))", offset)
267            }
268            ReadState::Body {
269                header,
270                offset,
271                buffer,
272            } => {
273                write!(
274                    f,
275                    "(ReadState::Body (header {}) (offset {}) (buffer-len {}))",
276                    header,
277                    offset,
278                    buffer.len()
279                )
280            }
281        }
282    }
283}
284
285/// Possible errors while decoding a message frame.
286#[non_exhaustive]
287#[derive(Debug)]
288pub enum FrameDecodeError {
289    /// An I/O error.
290    Io(io::Error),
291    /// Decoding the frame header failed.
292    Header(HeaderDecodeError),
293    /// A data frame body length is larger than the configured maximum.
294    FrameTooLarge(usize),
295}
296
297impl PartialEq for FrameDecodeError {
298    fn eq(&self, other: &Self) -> bool {
299        match (self, other) {
300            (FrameDecodeError::Io(e1), FrameDecodeError::Io(e2)) => e1.kind() == e2.kind(),
301            (FrameDecodeError::Header(e1), FrameDecodeError::Header(e2)) => e1 == e2,
302            (FrameDecodeError::FrameTooLarge(n1), FrameDecodeError::FrameTooLarge(n2)) => n1 == n2,
303            _ => false,
304        }
305    }
306}
307
308impl std::fmt::Display for FrameDecodeError {
309    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
310        match self {
311            FrameDecodeError::Io(e) => write!(f, "i/o error: {}", e),
312            FrameDecodeError::Header(e) => write!(f, "decode error: {}", e),
313            FrameDecodeError::FrameTooLarge(n) => write!(f, "frame body is too large ({})", n),
314        }
315    }
316}
317
318impl std::error::Error for FrameDecodeError {
319    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
320        match self {
321            FrameDecodeError::Io(e) => Some(e),
322            FrameDecodeError::Header(e) => Some(e),
323            FrameDecodeError::FrameTooLarge(_) => None,
324        }
325    }
326}
327
328impl From<std::io::Error> for FrameDecodeError {
329    fn from(e: std::io::Error) -> Self {
330        FrameDecodeError::Io(e)
331    }
332}
333
334impl From<HeaderDecodeError> for FrameDecodeError {
335    fn from(e: HeaderDecodeError) -> Self {
336        FrameDecodeError::Header(e)
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen, QuickCheck};

    impl Arbitrary for Frame<()> {
        /// Generates a random frame; data frames get a body of exactly `header.len()` bytes.
        fn arbitrary(g: &mut Gen) -> Self {
            let mut header: header::Header<()> = Arbitrary::arbitrary(g);
            let body = if header.tag() == header::Tag::Data {
                // Keep bodies small so that 10k iterations stay fast.
                header.set_len(header.len().val() % 4096);
                // Draw the bytes from `g` so all randomness comes from QuickCheck's generator.
                (0..header.len().val()).map(|_| u8::arbitrary(g)).collect()
            } else {
                Vec::new()
            };
            Frame { header, body }
        }
    }

    /// Writing a frame and reading it back yields the same frame.
    #[test]
    fn encode_decode_identity() {
        fn property(f: Frame<()>) -> bool {
            futures::executor::block_on(async move {
                let id = crate::yamux::connection::Id::random();
                let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()), f.body.len());
                if io.send(f.clone()).await.is_err() {
                    return false;
                }
                if io.flush().await.is_err() {
                    return false;
                }
                // Rewind the in-memory cursor so the written bytes can be read back.
                io.io.set_position(0);
                if let Ok(Some(x)) = io.try_next().await {
                    x == f
                } else {
                    false
                }
            })
        }

        QuickCheck::new().tests(10_000).quickcheck(property as fn(Frame<()>) -> bool)
    }
}