1use super::{
12 header::{self, HeaderDecodeError},
13 Frame,
14};
15use crate::connection::Id;
16use futures::{prelude::*, ready};
17use std::{
18 fmt, io,
19 pin::Pin,
20 task::{Context, Poll},
21};
22
23#[derive(Debug)]
25pub(crate) struct Io<T> {
26 id: Id,
27 io: T,
28 read_state: ReadState,
29 write_state: WriteState,
30 max_body_len: usize,
31}
32
33impl<T: AsyncRead + AsyncWrite + Unpin> Io<T> {
34 pub(crate) fn new(id: Id, io: T, max_frame_body_len: usize) -> Self {
35 Io {
36 id,
37 io,
38 read_state: ReadState::Init,
39 write_state: WriteState::Init,
40 max_body_len: max_frame_body_len,
41 }
42 }
43}
44
45enum WriteState {
47 Init,
48 Header {
49 header: [u8; header::HEADER_SIZE],
50 buffer: Vec<u8>,
51 offset: usize,
52 },
53 Body {
54 buffer: Vec<u8>,
55 offset: usize,
56 },
57}
58
59impl fmt::Debug for WriteState {
60 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
61 match self {
62 WriteState::Init => f.write_str("(WriteState::Init)"),
63 WriteState::Header { offset, .. } => {
64 write!(f, "(WriteState::Header (offset {}))", offset)
65 }
66 WriteState::Body { offset, buffer } => {
67 write!(
68 f,
69 "(WriteState::Body (offset {}) (buffer-len {}))",
70 offset,
71 buffer.len()
72 )
73 }
74 }
75 }
76}
77
78impl<T: AsyncRead + AsyncWrite + Unpin> Sink<Frame<()>> for Io<T> {
79 type Error = io::Error;
80
81 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
82 let this = Pin::into_inner(self);
83 loop {
84 log::trace!("{}: write: {:?}", this.id, this.write_state);
85 match &mut this.write_state {
86 WriteState::Init => return Poll::Ready(Ok(())),
87 WriteState::Header {
88 header,
89 buffer,
90 ref mut offset,
91 } => match Pin::new(&mut this.io).poll_write(cx, &header[*offset..]) {
92 Poll::Pending => return Poll::Pending,
93 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
94 Poll::Ready(Ok(n)) => {
95 if n == 0 {
96 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
97 }
98 *offset += n;
99 if *offset == header.len() {
100 if !buffer.is_empty() {
101 let buffer = std::mem::take(buffer);
102 this.write_state = WriteState::Body { buffer, offset: 0 };
103 } else {
104 this.write_state = WriteState::Init;
105 }
106 }
107 }
108 },
109 WriteState::Body {
110 buffer,
111 ref mut offset,
112 } => match Pin::new(&mut this.io).poll_write(cx, &buffer[*offset..]) {
113 Poll::Pending => return Poll::Pending,
114 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
115 Poll::Ready(Ok(n)) => {
116 if n == 0 {
117 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
118 }
119 *offset += n;
120 if *offset == buffer.len() {
121 this.write_state = WriteState::Init;
122 }
123 }
124 },
125 }
126 }
127 }
128
129 fn start_send(self: Pin<&mut Self>, f: Frame<()>) -> Result<(), Self::Error> {
130 let header = header::encode(&f.header);
131 let buffer = f.body;
132 self.get_mut().write_state = WriteState::Header {
133 header,
134 buffer,
135 offset: 0,
136 };
137 Ok(())
138 }
139
140 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
141 let this = Pin::into_inner(self);
142 ready!(this.poll_ready_unpin(cx))?;
143 Pin::new(&mut this.io).poll_flush(cx)
144 }
145
146 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
147 let this = Pin::into_inner(self);
148 ready!(this.poll_ready_unpin(cx))?;
149 Pin::new(&mut this.io).poll_close(cx)
150 }
151}
152
153enum ReadState {
155 Init,
157 Header {
159 offset: usize,
160 buffer: [u8; header::HEADER_SIZE],
161 },
162 Body {
164 header: header::Header<()>,
165 offset: usize,
166 buffer: Vec<u8>,
167 },
168}
169
170impl<T: AsyncRead + AsyncWrite + Unpin> Stream for Io<T> {
171 type Item = Result<Frame<()>, FrameDecodeError>;
172
173 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
174 let mut this = &mut *self;
175 loop {
176 log::trace!("{}: read: {:?}", this.id, this.read_state);
177 match this.read_state {
178 ReadState::Init => {
179 this.read_state = ReadState::Header {
180 offset: 0,
181 buffer: [0; header::HEADER_SIZE],
182 };
183 }
184 ReadState::Header {
185 ref mut offset,
186 ref mut buffer,
187 } => {
188 if *offset == header::HEADER_SIZE {
189 let header = match header::decode(buffer) {
190 Ok(hd) => hd,
191 Err(e) => return Poll::Ready(Some(Err(e.into()))),
192 };
193
194 log::trace!("{}: read: {}", this.id, header);
195
196 if header.tag() != header::Tag::Data {
197 this.read_state = ReadState::Init;
198 return Poll::Ready(Some(Ok(Frame::new(header))));
199 }
200
201 let body_len = header.len().val() as usize;
202
203 if body_len > this.max_body_len {
204 return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge(
205 body_len,
206 ))));
207 }
208
209 this.read_state = ReadState::Body {
210 header,
211 offset: 0,
212 buffer: vec![0; body_len],
213 };
214
215 continue;
216 }
217
218 let buf = &mut buffer[*offset..header::HEADER_SIZE];
219 match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
220 0 => {
221 if *offset == 0 {
222 return Poll::Ready(None);
223 }
224 let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
225 return Poll::Ready(Some(Err(e)));
226 }
227 n => *offset += n,
228 }
229 }
230 ReadState::Body {
231 ref header,
232 ref mut offset,
233 ref mut buffer,
234 } => {
235 let body_len = header.len().val() as usize;
236
237 if *offset == body_len {
238 let h = header.clone();
239 let v = std::mem::take(buffer);
240 this.read_state = ReadState::Init;
241 return Poll::Ready(Some(Ok(Frame { header: h, body: v })));
242 }
243
244 let buf = &mut buffer[*offset..body_len];
245 match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
246 0 => {
247 let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
248 return Poll::Ready(Some(Err(e)));
249 }
250 n => *offset += n,
251 }
252 }
253 }
254 }
255 }
256}
257
258impl fmt::Debug for ReadState {
259 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
260 match self {
261 ReadState::Init => f.write_str("(ReadState::Init)"),
262 ReadState::Header { offset, .. } => {
263 write!(f, "(ReadState::Header (offset {}))", offset)
264 }
265 ReadState::Body {
266 header,
267 offset,
268 buffer,
269 } => {
270 write!(
271 f,
272 "(ReadState::Body (header {}) (offset {}) (buffer-len {}))",
273 header,
274 offset,
275 buffer.len()
276 )
277 }
278 }
279 }
280}
281
282#[non_exhaustive]
284#[derive(Debug)]
285pub enum FrameDecodeError {
286 Io(io::Error),
288 Header(HeaderDecodeError),
290 FrameTooLarge(usize),
292}
293
294impl std::fmt::Display for FrameDecodeError {
295 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
296 match self {
297 FrameDecodeError::Io(e) => write!(f, "i/o error: {}", e),
298 FrameDecodeError::Header(e) => write!(f, "decode error: {}", e),
299 FrameDecodeError::FrameTooLarge(n) => write!(f, "frame body is too large ({})", n),
300 }
301 }
302}
303
304impl std::error::Error for FrameDecodeError {
305 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
306 match self {
307 FrameDecodeError::Io(e) => Some(e),
308 FrameDecodeError::Header(e) => Some(e),
309 FrameDecodeError::FrameTooLarge(_) => None,
310 }
311 }
312}
313
314impl From<std::io::Error> for FrameDecodeError {
315 fn from(e: std::io::Error) -> Self {
316 FrameDecodeError::Io(e)
317 }
318}
319
320impl From<HeaderDecodeError> for FrameDecodeError {
321 fn from(e: HeaderDecodeError) -> Self {
322 FrameDecodeError::Header(e)
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen, QuickCheck};
    use rand::RngCore;

    /// Random frames: arbitrary headers; data frames get a random body of at most 4095
    /// bytes whose length matches the header's length field.
    impl Arbitrary for Frame<()> {
        fn arbitrary(g: &mut Gen) -> Self {
            let mut header: header::Header<()> = Arbitrary::arbitrary(g);
            let mut body = Vec::new();
            if header.tag() == header::Tag::Data {
                // Bound the body size to keep the 10k-case property run fast.
                header.set_len(header.len().val() % 4096);
                body.resize(header.len().val() as usize, 0);
                rand::thread_rng().fill_bytes(&mut body);
            }
            Frame { header, body }
        }
    }

    /// Writing a frame and reading it back yields the identical frame.
    #[test]
    fn encode_decode_identity() {
        fn property(f: Frame<()>) -> bool {
            futures::executor::block_on(async move {
                let id = crate::connection::Id::random();
                let cursor = futures::io::Cursor::new(Vec::new());
                // The limit equals the body length, so the reader's size check is
                // exercised exactly at its boundary.
                let mut io = Io::new(id, cursor, f.body.len());
                let written = io.send(f.clone()).await.is_ok() && io.flush().await.is_ok();
                if !written {
                    return false;
                }
                // Rewind so the reader sees exactly the bytes just written.
                io.io.set_position(0);
                matches!(io.try_next().await, Ok(Some(x)) if x == f)
            })
        }

        QuickCheck::new()
            .tests(10_000)
            .quickcheck(property as fn(Frame<()>) -> bool)
    }
}