yamux/frame/
io.rs

1// Copyright (c) 2019 Parity Technologies (UK) Ltd.
2//
3// Licensed under the Apache License, Version 2.0 or MIT license, at your option.
4//
5// A copy of the Apache License, Version 2.0 is included in the software as
6// LICENSE-APACHE and a copy of the MIT license is included in the software
7// as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0
8// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
9// at https://opensource.org/licenses/MIT.
10
11use super::{
12    header::{self, HeaderDecodeError},
13    Frame,
14};
15use crate::connection::Id;
16use futures::{prelude::*, ready};
17use std::{
18    fmt, io,
19    pin::Pin,
20    task::{Context, Poll},
21};
22
23/// A [`Stream`] and writer of [`Frame`] values.
24#[derive(Debug)]
25pub(crate) struct Io<T> {
26    id: Id,
27    io: T,
28    read_state: ReadState,
29    write_state: WriteState,
30    max_body_len: usize,
31}
32
33impl<T: AsyncRead + AsyncWrite + Unpin> Io<T> {
34    pub(crate) fn new(id: Id, io: T, max_frame_body_len: usize) -> Self {
35        Io {
36            id,
37            io,
38            read_state: ReadState::Init,
39            write_state: WriteState::Init,
40            max_body_len: max_frame_body_len,
41        }
42    }
43}
44
45/// The stages of writing a new `Frame`.
46enum WriteState {
47    Init,
48    Header {
49        header: [u8; header::HEADER_SIZE],
50        buffer: Vec<u8>,
51        offset: usize,
52    },
53    Body {
54        buffer: Vec<u8>,
55        offset: usize,
56    },
57}
58
59impl fmt::Debug for WriteState {
60    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
61        match self {
62            WriteState::Init => f.write_str("(WriteState::Init)"),
63            WriteState::Header { offset, .. } => {
64                write!(f, "(WriteState::Header (offset {}))", offset)
65            }
66            WriteState::Body { offset, buffer } => {
67                write!(
68                    f,
69                    "(WriteState::Body (offset {}) (buffer-len {}))",
70                    offset,
71                    buffer.len()
72                )
73            }
74        }
75    }
76}
77
/// Encodes [`Frame`]s onto the underlying writer, one frame at a time.
///
/// `start_send` only queues the encoded frame in `write_state`; the bytes
/// are pushed to the writer by `poll_ready` (which `poll_flush` and
/// `poll_close` also drive before delegating to the writer).
impl<T: AsyncRead + AsyncWrite + Unpin> Sink<Frame<()>> for Io<T> {
    type Error = io::Error;

    /// Writes out whatever is left of the pending frame (header, then body).
    ///
    /// Returns `Ready(Ok(()))` only once no frame is pending, i.e. the write
    /// state is back to `Init`. A zero-byte write from the underlying writer
    /// is reported as an `io::ErrorKind::WriteZero` error.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::into_inner(self);
        loop {
            log::trace!("{}: write: {:?}", this.id, this.write_state);
            match &mut this.write_state {
                WriteState::Init => return Poll::Ready(Ok(())),
                WriteState::Header {
                    header,
                    buffer,
                    ref mut offset,
                } => match Pin::new(&mut this.io).poll_write(cx, &header[*offset..]) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                    Poll::Ready(Ok(n)) => {
                        // A writer that accepts nothing would make this loop spin forever.
                        if n == 0 {
                            return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
                        }
                        *offset += n;
                        if *offset == header.len() {
                            // Header complete: move the body (if any) into the
                            // `Body` state without copying it; an empty body
                            // means the frame is done.
                            if !buffer.is_empty() {
                                let buffer = std::mem::take(buffer);
                                this.write_state = WriteState::Body { buffer, offset: 0 };
                            } else {
                                this.write_state = WriteState::Init;
                            }
                        }
                    }
                },
                WriteState::Body {
                    buffer,
                    ref mut offset,
                } => match Pin::new(&mut this.io).poll_write(cx, &buffer[*offset..]) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                    Poll::Ready(Ok(n)) => {
                        if n == 0 {
                            return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
                        }
                        *offset += n;
                        if *offset == buffer.len() {
                            this.write_state = WriteState::Init;
                        }
                    }
                },
            }
        }
    }

    /// Encodes the frame header and queues the frame for writing.
    ///
    /// No bytes are written here. The `Sink` contract requires a prior
    /// `poll_ready` returning `Ready(Ok(()))`; this method unconditionally
    /// overwrites `write_state`, so calling it while a frame is still
    /// pending would discard that frame.
    fn start_send(self: Pin<&mut Self>, f: Frame<()>) -> Result<(), Self::Error> {
        let header = header::encode(&f.header);
        let buffer = f.body;
        self.get_mut().write_state = WriteState::Header {
            header,
            buffer,
            offset: 0,
        };
        Ok(())
    }

    /// Finishes writing the pending frame, then flushes the underlying writer.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::into_inner(self);
        ready!(this.poll_ready_unpin(cx))?;
        Pin::new(&mut this.io).poll_flush(cx)
    }

    /// Finishes writing the pending frame, then closes the underlying writer.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::into_inner(self);
        ready!(this.poll_ready_unpin(cx))?;
        Pin::new(&mut this.io).poll_close(cx)
    }
}
152
153/// The stages of reading a new `Frame`.
154enum ReadState {
155    /// Initial reading state.
156    Init,
157    /// Reading the frame header.
158    Header {
159        offset: usize,
160        buffer: [u8; header::HEADER_SIZE],
161    },
162    /// Reading the frame body.
163    Body {
164        header: header::Header<()>,
165        offset: usize,
166        buffer: Vec<u8>,
167    },
168}
169
170impl<T: AsyncRead + AsyncWrite + Unpin> Stream for Io<T> {
171    type Item = Result<Frame<()>, FrameDecodeError>;
172
173    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
174        let mut this = &mut *self;
175        loop {
176            log::trace!("{}: read: {:?}", this.id, this.read_state);
177            match this.read_state {
178                ReadState::Init => {
179                    this.read_state = ReadState::Header {
180                        offset: 0,
181                        buffer: [0; header::HEADER_SIZE],
182                    };
183                }
184                ReadState::Header {
185                    ref mut offset,
186                    ref mut buffer,
187                } => {
188                    if *offset == header::HEADER_SIZE {
189                        let header = match header::decode(buffer) {
190                            Ok(hd) => hd,
191                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
192                        };
193
194                        log::trace!("{}: read: {}", this.id, header);
195
196                        if header.tag() != header::Tag::Data {
197                            this.read_state = ReadState::Init;
198                            return Poll::Ready(Some(Ok(Frame::new(header))));
199                        }
200
201                        let body_len = header.len().val() as usize;
202
203                        if body_len > this.max_body_len {
204                            return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge(
205                                body_len,
206                            ))));
207                        }
208
209                        this.read_state = ReadState::Body {
210                            header,
211                            offset: 0,
212                            buffer: vec![0; body_len],
213                        };
214
215                        continue;
216                    }
217
218                    let buf = &mut buffer[*offset..header::HEADER_SIZE];
219                    match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
220                        0 => {
221                            if *offset == 0 {
222                                return Poll::Ready(None);
223                            }
224                            let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
225                            return Poll::Ready(Some(Err(e)));
226                        }
227                        n => *offset += n,
228                    }
229                }
230                ReadState::Body {
231                    ref header,
232                    ref mut offset,
233                    ref mut buffer,
234                } => {
235                    let body_len = header.len().val() as usize;
236
237                    if *offset == body_len {
238                        let h = header.clone();
239                        let v = std::mem::take(buffer);
240                        this.read_state = ReadState::Init;
241                        return Poll::Ready(Some(Ok(Frame { header: h, body: v })));
242                    }
243
244                    let buf = &mut buffer[*offset..body_len];
245                    match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
246                        0 => {
247                            let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
248                            return Poll::Ready(Some(Err(e)));
249                        }
250                        n => *offset += n,
251                    }
252                }
253            }
254        }
255    }
256}
257
258impl fmt::Debug for ReadState {
259    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
260        match self {
261            ReadState::Init => f.write_str("(ReadState::Init)"),
262            ReadState::Header { offset, .. } => {
263                write!(f, "(ReadState::Header (offset {}))", offset)
264            }
265            ReadState::Body {
266                header,
267                offset,
268                buffer,
269            } => {
270                write!(
271                    f,
272                    "(ReadState::Body (header {}) (offset {}) (buffer-len {}))",
273                    header,
274                    offset,
275                    buffer.len()
276                )
277            }
278        }
279    }
280}
281
282/// Possible errors while decoding a message frame.
283#[non_exhaustive]
284#[derive(Debug)]
285pub enum FrameDecodeError {
286    /// An I/O error.
287    Io(io::Error),
288    /// Decoding the frame header failed.
289    Header(HeaderDecodeError),
290    /// A data frame body length is larger than the configured maximum.
291    FrameTooLarge(usize),
292}
293
294impl std::fmt::Display for FrameDecodeError {
295    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
296        match self {
297            FrameDecodeError::Io(e) => write!(f, "i/o error: {}", e),
298            FrameDecodeError::Header(e) => write!(f, "decode error: {}", e),
299            FrameDecodeError::FrameTooLarge(n) => write!(f, "frame body is too large ({})", n),
300        }
301    }
302}
303
304impl std::error::Error for FrameDecodeError {
305    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
306        match self {
307            FrameDecodeError::Io(e) => Some(e),
308            FrameDecodeError::Header(e) => Some(e),
309            FrameDecodeError::FrameTooLarge(_) => None,
310        }
311    }
312}
313
314impl From<std::io::Error> for FrameDecodeError {
315    fn from(e: std::io::Error) -> Self {
316        FrameDecodeError::Io(e)
317    }
318}
319
320impl From<HeaderDecodeError> for FrameDecodeError {
321    fn from(e: HeaderDecodeError) -> Self {
322        FrameDecodeError::Header(e)
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen, QuickCheck};
    use rand::RngCore;

    /// Generates random frames. Data frames get a random body (shorter than
    /// 4096 bytes) whose length matches the header's length field; all other
    /// frame types get an empty body.
    impl Arbitrary for Frame<()> {
        fn arbitrary(g: &mut Gen) -> Self {
            let mut header: header::Header<()> = Arbitrary::arbitrary(g);
            let mut body = Vec::new();
            if header.tag() == header::Tag::Data {
                header.set_len(header.len().val() % 4096);
                body.resize(header.len().val() as usize, 0u8);
                rand::thread_rng().fill_bytes(&mut body);
            }
            Frame { header, body }
        }
    }

    /// A frame written through `Io` into an in-memory buffer and read back
    /// from the start of that buffer must equal the original frame.
    #[test]
    fn encode_decode_identity() {
        fn roundtrip(frame: Frame<()>) -> bool {
            futures::executor::block_on(async move {
                let id = crate::connection::Id::random();
                let cursor = futures::io::Cursor::new(Vec::new());
                let mut io = Io::new(id, cursor, frame.body.len());
                let written = io.send(frame.clone()).await.is_ok() && io.flush().await.is_ok();
                if !written {
                    return false;
                }
                io.io.set_position(0);
                match io.try_next().await {
                    Ok(Some(decoded)) => decoded == frame,
                    _ => false,
                }
            })
        }

        QuickCheck::new()
            .tests(10_000)
            .quickcheck(roundtrip as fn(Frame<()>) -> bool)
    }
}