yamux/frame/
io.rs

1// Copyright (c) 2019 Parity Technologies (UK) Ltd.
2//
3// Licensed under the Apache License, Version 2.0 or MIT license, at your option.
4//
5// A copy of the Apache License, Version 2.0 is included in the software as
6// LICENSE-APACHE and a copy of the MIT license is included in the software
7// as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0
8// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
9// at https://opensource.org/licenses/MIT.
10
11use super::{
12    header::{self, HeaderDecodeError},
13    Frame,
14};
15use crate::connection::Id;
16use futures::{prelude::*, ready};
17use std::{
18    fmt, io,
19    pin::Pin,
20    task::{Context, Poll},
21};
22
/// Maximum Yamux frame body length.
///
/// Limits the number of bytes a remote can cause the local node to allocate at once when
/// reading: a data frame whose header announces a longer body is rejected with
/// [`FrameDecodeError::FrameTooLarge`] before the body buffer is allocated.
///
/// The value was chosen based on intuition in past iterations.
const MAX_FRAME_BODY_LEN: usize = crate::MIB;
29
/// A [`Stream`] and writer of [`Frame`] values.
///
/// Frames are read through the [`Stream`] implementation and written through the
/// [`Sink`] implementation; each direction keeps its own state machine.
#[derive(Debug)]
pub(crate) struct Io<T> {
    /// Connection identifier; printed as the prefix of this type's trace log lines.
    id: Id,
    /// The underlying I/O resource that frames are read from and written to.
    io: T,
    /// Progress of the frame currently being read.
    read_state: ReadState,
    /// Progress of the frame currently being written.
    write_state: WriteState,
}
38
39impl<T: AsyncRead + AsyncWrite + Unpin> Io<T> {
40    pub(crate) fn new(id: Id, io: T) -> Self {
41        Io {
42            id,
43            io,
44            read_state: ReadState::Init,
45            write_state: WriteState::Init,
46        }
47    }
48}
49
/// The stages of writing a new `Frame`.
enum WriteState {
    /// No frame is pending; the sink is ready to accept a new one.
    Init,
    /// The encoded frame header is being written.
    Header {
        /// The encoded header bytes.
        header: [u8; header::HEADER_SIZE],
        /// The frame body, written once the header is complete (may be empty).
        buffer: Vec<u8>,
        /// Number of header bytes written so far.
        offset: usize,
    },
    /// The frame body is being written.
    Body {
        /// The frame body bytes.
        buffer: Vec<u8>,
        /// Number of body bytes written so far.
        offset: usize,
    },
    /// The underlying writer reported more bytes written than it was given.
    /// Once in this state, `poll_ready` returns an error.
    Poisoned,
}
64
65impl fmt::Debug for WriteState {
66    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
67        match self {
68            WriteState::Init => f.write_str("(WriteState::Init)"),
69            WriteState::Header { offset, .. } => {
70                write!(f, "(WriteState::Header (offset {offset}))")
71            }
72            WriteState::Body { offset, buffer } => {
73                write!(
74                    f,
75                    "(WriteState::Body (offset {}) (buffer-len {}))",
76                    offset,
77                    buffer.len()
78                )
79            }
80            WriteState::Poisoned => f.write_str("(WriteState::Poisoned)"),
81        }
82    }
83}
84
85impl<T: AsyncRead + AsyncWrite + Unpin> Sink<Frame<()>> for Io<T> {
86    type Error = io::Error;
87
88    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89        let this = Pin::into_inner(self);
90        loop {
91            log::trace!("{}: write: {:?}", this.id, this.write_state);
92            match &mut this.write_state {
93                WriteState::Init => return Poll::Ready(Ok(())),
94                WriteState::Header {
95                    header,
96                    buffer,
97                    ref mut offset,
98                } => match Pin::new(&mut this.io).poll_write(cx, &header[*offset..]) {
99                    Poll::Pending => return Poll::Pending,
100                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
101                    Poll::Ready(Ok(n)) => {
102                        if n == 0 {
103                            return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
104                        }
105                        *offset += n;
106
107                        if *offset > header.len() {
108                            let err = io::Error::other(format!(
109                                "Writer header returned invalid write count n={n}: {offset} > {} ",
110                                header.len(),
111                            ));
112
113                            this.write_state = WriteState::Poisoned;
114
115                            return Poll::Ready(Err(err));
116                        }
117
118                        if *offset == header.len() {
119                            if !buffer.is_empty() {
120                                let buffer = std::mem::take(buffer);
121                                this.write_state = WriteState::Body { buffer, offset: 0 };
122                            } else {
123                                this.write_state = WriteState::Init;
124                            }
125                        }
126                    }
127                },
128                WriteState::Body {
129                    buffer,
130                    ref mut offset,
131                } => match Pin::new(&mut this.io).poll_write(cx, &buffer[*offset..]) {
132                    Poll::Pending => return Poll::Pending,
133                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
134                    Poll::Ready(Ok(n)) => {
135                        if n == 0 {
136                            return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
137                        }
138                        *offset += n;
139
140                        if *offset > buffer.len() {
141                            let err = io::Error::other(format!(
142                                "Writer body returned invalid write count n={n}: {offset} > {} ",
143                                buffer.len(),
144                            ));
145
146                            this.write_state = WriteState::Poisoned;
147
148                            return Poll::Ready(Err(err));
149                        }
150
151                        if *offset == buffer.len() {
152                            this.write_state = WriteState::Init;
153                        }
154                    }
155                },
156                WriteState::Poisoned => {
157                    return Poll::Ready(Err(io::Error::other(
158                        "Sink is in poisoned state due to previous write error",
159                    )))
160                }
161            }
162        }
163    }
164
165    fn start_send(self: Pin<&mut Self>, f: Frame<()>) -> Result<(), Self::Error> {
166        let header = header::encode(&f.header);
167        let buffer = f.body;
168        self.get_mut().write_state = WriteState::Header {
169            header,
170            buffer,
171            offset: 0,
172        };
173        Ok(())
174    }
175
176    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177        let this = Pin::into_inner(self);
178        ready!(this.poll_ready_unpin(cx))?;
179        Pin::new(&mut this.io).poll_flush(cx)
180    }
181
182    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
183        let this = Pin::into_inner(self);
184        ready!(this.poll_ready_unpin(cx))?;
185        Pin::new(&mut this.io).poll_close(cx)
186    }
187}
188
/// The stages of reading a new `Frame`.
enum ReadState {
    /// Initial reading state.
    Init,
    /// Reading the frame header.
    Header {
        /// Number of header bytes read so far.
        offset: usize,
        /// Fixed-size buffer the header bytes are read into.
        buffer: [u8; header::HEADER_SIZE],
    },
    /// Reading the frame body.
    Body {
        /// The already decoded header; its length field gives the body size.
        header: header::Header<()>,
        /// Number of body bytes read so far.
        offset: usize,
        /// Buffer the body is read into, allocated at the full body length.
        buffer: Vec<u8>,
    },
}
205
206impl<T: AsyncRead + AsyncWrite + Unpin> Stream for Io<T> {
207    type Item = Result<Frame<()>, FrameDecodeError>;
208
209    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
210        let this = &mut *self;
211        loop {
212            log::trace!("{}: read: {:?}", this.id, this.read_state);
213            match this.read_state {
214                ReadState::Init => {
215                    this.read_state = ReadState::Header {
216                        offset: 0,
217                        buffer: [0; header::HEADER_SIZE],
218                    };
219                }
220                ReadState::Header {
221                    ref mut offset,
222                    ref mut buffer,
223                } => {
224                    if *offset == header::HEADER_SIZE {
225                        let header = match header::decode(buffer) {
226                            Ok(hd) => hd,
227                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
228                        };
229
230                        log::trace!("{}: read: {}", this.id, header);
231
232                        if header.tag() != header::Tag::Data {
233                            this.read_state = ReadState::Init;
234                            return Poll::Ready(Some(Ok(Frame::new(header))));
235                        }
236
237                        let body_len = header.len().val() as usize;
238
239                        if body_len > MAX_FRAME_BODY_LEN {
240                            return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge(
241                                body_len,
242                            ))));
243                        }
244
245                        this.read_state = ReadState::Body {
246                            header,
247                            offset: 0,
248                            buffer: vec![0; body_len],
249                        };
250
251                        continue;
252                    }
253
254                    let buf = &mut buffer[*offset..header::HEADER_SIZE];
255                    match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
256                        0 => {
257                            if *offset == 0 {
258                                return Poll::Ready(None);
259                            }
260                            let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
261                            return Poll::Ready(Some(Err(e)));
262                        }
263                        n => *offset += n,
264                    }
265                }
266                ReadState::Body {
267                    ref header,
268                    ref mut offset,
269                    ref mut buffer,
270                } => {
271                    let body_len = header.len().val() as usize;
272
273                    if *offset == body_len {
274                        let h = header.clone();
275                        let v = std::mem::take(buffer);
276                        this.read_state = ReadState::Init;
277                        return Poll::Ready(Some(Ok(Frame { header: h, body: v })));
278                    }
279
280                    let buf = &mut buffer[*offset..body_len];
281                    match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
282                        0 => {
283                            let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
284                            return Poll::Ready(Some(Err(e)));
285                        }
286                        n => *offset += n,
287                    }
288                }
289            }
290        }
291    }
292}
293
294impl fmt::Debug for ReadState {
295    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
296        match self {
297            ReadState::Init => f.write_str("(ReadState::Init)"),
298            ReadState::Header { offset, .. } => {
299                write!(f, "(ReadState::Header (offset {offset}))")
300            }
301            ReadState::Body {
302                header,
303                offset,
304                buffer,
305            } => {
306                write!(
307                    f,
308                    "(ReadState::Body (header {}) (offset {}) (buffer-len {}))",
309                    header,
310                    offset,
311                    buffer.len()
312                )
313            }
314        }
315    }
316}
317
/// Possible errors while decoding a message frame.
#[non_exhaustive]
#[derive(Debug)]
pub enum FrameDecodeError {
    /// An I/O error, including an unexpected end-of-file in the middle of a frame.
    Io(io::Error),
    /// Decoding the frame header failed.
    Header(HeaderDecodeError),
    /// A data frame body length is larger than the configured maximum.
    ///
    /// Carries the body length announced by the frame header.
    FrameTooLarge(usize),
}
329
330impl std::fmt::Display for FrameDecodeError {
331    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
332        match self {
333            FrameDecodeError::Io(e) => write!(f, "i/o error: {e}"),
334            FrameDecodeError::Header(e) => write!(f, "decode error: {e}"),
335            FrameDecodeError::FrameTooLarge(n) => write!(f, "frame body is too large ({n})"),
336        }
337    }
338}
339
340impl std::error::Error for FrameDecodeError {
341    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
342        match self {
343            FrameDecodeError::Io(e) => Some(e),
344            FrameDecodeError::Header(e) => Some(e),
345            FrameDecodeError::FrameTooLarge(_) => None,
346        }
347    }
348}
349
350impl From<std::io::Error> for FrameDecodeError {
351    fn from(e: std::io::Error) -> Self {
352        FrameDecodeError::Io(e)
353    }
354}
355
356impl From<HeaderDecodeError> for FrameDecodeError {
357    fn from(e: HeaderDecodeError) -> Self {
358        FrameDecodeError::Header(e)
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen, QuickCheck};
    use rand::RngCore;

    impl Arbitrary for Frame<()> {
        /// Generates a frame with an arbitrary header. Data frames get a random
        /// body of fewer than 4096 bytes whose length matches the header; all
        /// other frames have an empty body.
        fn arbitrary(g: &mut Gen) -> Self {
            let mut header: header::Header<()> = Arbitrary::arbitrary(g);
            if header.tag() != header::Tag::Data {
                return Frame {
                    header,
                    body: Vec::new(),
                };
            }

            header.set_len(header.len().val() % 4096);
            let mut body = vec![0; header.len().val() as usize];
            rand::rng().fill_bytes(&mut body);
            Frame { header, body }
        }
    }

    /// Writing a frame into an in-memory cursor and reading it back yields the
    /// same frame.
    #[test]
    fn encode_decode_identity() {
        fn property(f: Frame<()>) -> bool {
            futures::executor::block_on(async move {
                let id = crate::connection::Id::random();
                let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()));

                let written = io.send(f.clone()).await.is_ok() && io.flush().await.is_ok();
                if !written {
                    return false;
                }

                // Rewind so the frame just written is read back.
                io.io.set_position(0);
                matches!(io.try_next().await, Ok(Some(x)) if x == f)
            })
        }

        QuickCheck::new()
            .tests(10_000)
            .quickcheck(property as fn(Frame<()>) -> bool)
    }
}