1use super::{
12 header::{self, HeaderDecodeError},
13 Frame,
14};
15use crate::connection::Id;
16use futures::{prelude::*, ready};
17use std::{
18 fmt, io,
19 pin::Pin,
20 task::{Context, Poll},
21};
22
/// Upper bound on the body length of a `Data` frame accepted by the decoder.
///
/// A `Data` header announcing a longer body is rejected with
/// [`FrameDecodeError::FrameTooLarge`] before any body buffer is allocated.
/// NOTE(review): `crate::MIB` is presumably one mebibyte — confirm against its definition.
const MAX_FRAME_BODY_LEN: usize = crate::MIB;
29
/// Frame codec on top of a byte transport `T`.
///
/// Decodes incoming frames via its `Stream` impl and encodes outgoing frames
/// via its `Sink` impl. Both directions keep their progress in explicit state
/// machines, so a `Poll::Pending` from `io` never loses partially transferred
/// bytes.
#[derive(Debug)]
pub(crate) struct Io<T> {
    /// Connection identifier; used as the prefix of trace log lines.
    id: Id,
    /// The underlying transport frames are read from and written to.
    io: T,
    /// Progress of the frame currently being read.
    read_state: ReadState,
    /// Progress of the frame currently being written.
    write_state: WriteState,
}
38
39impl<T: AsyncRead + AsyncWrite + Unpin> Io<T> {
40 pub(crate) fn new(id: Id, io: T) -> Self {
41 Io {
42 id,
43 io,
44 read_state: ReadState::Init,
45 write_state: WriteState::Init,
46 }
47 }
48}
49
/// Progress of the frame currently being written by [`Io`]'s `Sink` impl.
enum WriteState {
    /// Nothing pending; the sink is ready to accept the next frame.
    Init,
    /// The encoded header is being written.
    Header {
        /// Encoded header bytes.
        header: [u8; header::HEADER_SIZE],
        /// Frame body to write once the header is done (may be empty).
        buffer: Vec<u8>,
        /// Number of header bytes already written.
        offset: usize,
    },
    /// The frame body is being written (only entered for non-empty bodies).
    Body {
        /// Frame body bytes.
        buffer: Vec<u8>,
        /// Number of body bytes already written.
        offset: usize,
    },
    /// Entered when the writer reported a write count larger than the bytes
    /// offered; every later `poll_ready` fails.
    Poisoned,
}
64
65impl fmt::Debug for WriteState {
66 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
67 match self {
68 WriteState::Init => f.write_str("(WriteState::Init)"),
69 WriteState::Header { offset, .. } => {
70 write!(f, "(WriteState::Header (offset {offset}))")
71 }
72 WriteState::Body { offset, buffer } => {
73 write!(
74 f,
75 "(WriteState::Body (offset {}) (buffer-len {}))",
76 offset,
77 buffer.len()
78 )
79 }
80 WriteState::Poisoned => f.write_str("(WriteState::Poisoned)"),
81 }
82 }
83}
84
85impl<T: AsyncRead + AsyncWrite + Unpin> Sink<Frame<()>> for Io<T> {
86 type Error = io::Error;
87
88 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89 let this = Pin::into_inner(self);
90 loop {
91 log::trace!("{}: write: {:?}", this.id, this.write_state);
92 match &mut this.write_state {
93 WriteState::Init => return Poll::Ready(Ok(())),
94 WriteState::Header {
95 header,
96 buffer,
97 ref mut offset,
98 } => match Pin::new(&mut this.io).poll_write(cx, &header[*offset..]) {
99 Poll::Pending => return Poll::Pending,
100 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
101 Poll::Ready(Ok(n)) => {
102 if n == 0 {
103 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
104 }
105 *offset += n;
106
107 if *offset > header.len() {
108 let err = io::Error::other(format!(
109 "Writer header returned invalid write count n={n}: {offset} > {} ",
110 header.len(),
111 ));
112
113 this.write_state = WriteState::Poisoned;
114
115 return Poll::Ready(Err(err));
116 }
117
118 if *offset == header.len() {
119 if !buffer.is_empty() {
120 let buffer = std::mem::take(buffer);
121 this.write_state = WriteState::Body { buffer, offset: 0 };
122 } else {
123 this.write_state = WriteState::Init;
124 }
125 }
126 }
127 },
128 WriteState::Body {
129 buffer,
130 ref mut offset,
131 } => match Pin::new(&mut this.io).poll_write(cx, &buffer[*offset..]) {
132 Poll::Pending => return Poll::Pending,
133 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
134 Poll::Ready(Ok(n)) => {
135 if n == 0 {
136 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
137 }
138 *offset += n;
139
140 if *offset > buffer.len() {
141 let err = io::Error::other(format!(
142 "Writer body returned invalid write count n={n}: {offset} > {} ",
143 buffer.len(),
144 ));
145
146 this.write_state = WriteState::Poisoned;
147
148 return Poll::Ready(Err(err));
149 }
150
151 if *offset == buffer.len() {
152 this.write_state = WriteState::Init;
153 }
154 }
155 },
156 WriteState::Poisoned => {
157 return Poll::Ready(Err(io::Error::other(
158 "Sink is in poisoned state due to previous write error",
159 )))
160 }
161 }
162 }
163 }
164
165 fn start_send(self: Pin<&mut Self>, f: Frame<()>) -> Result<(), Self::Error> {
166 let header = header::encode(&f.header);
167 let buffer = f.body;
168 self.get_mut().write_state = WriteState::Header {
169 header,
170 buffer,
171 offset: 0,
172 };
173 Ok(())
174 }
175
176 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177 let this = Pin::into_inner(self);
178 ready!(this.poll_ready_unpin(cx))?;
179 Pin::new(&mut this.io).poll_flush(cx)
180 }
181
182 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
183 let this = Pin::into_inner(self);
184 ready!(this.poll_ready_unpin(cx))?;
185 Pin::new(&mut this.io).poll_close(cx)
186 }
187}
188
/// Progress of the frame currently being decoded by [`Io`]'s `Stream` impl.
enum ReadState {
    /// No frame in progress; the next poll starts a fresh header.
    Init,
    /// Reading the fixed-size frame header.
    Header {
        /// Number of header bytes received so far.
        offset: usize,
        /// Buffer the header bytes are read into.
        buffer: [u8; header::HEADER_SIZE],
    },
    /// Reading the body of a `Data` frame whose header is already decoded.
    Body {
        /// The decoded header; `header.len()` is the expected body length.
        header: header::Header<()>,
        /// Number of body bytes received so far.
        offset: usize,
        /// Buffer allocated to the full body length.
        buffer: Vec<u8>,
    },
}
205
206impl<T: AsyncRead + AsyncWrite + Unpin> Stream for Io<T> {
207 type Item = Result<Frame<()>, FrameDecodeError>;
208
209 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
210 let this = &mut *self;
211 loop {
212 log::trace!("{}: read: {:?}", this.id, this.read_state);
213 match this.read_state {
214 ReadState::Init => {
215 this.read_state = ReadState::Header {
216 offset: 0,
217 buffer: [0; header::HEADER_SIZE],
218 };
219 }
220 ReadState::Header {
221 ref mut offset,
222 ref mut buffer,
223 } => {
224 if *offset == header::HEADER_SIZE {
225 let header = match header::decode(buffer) {
226 Ok(hd) => hd,
227 Err(e) => return Poll::Ready(Some(Err(e.into()))),
228 };
229
230 log::trace!("{}: read: {}", this.id, header);
231
232 if header.tag() != header::Tag::Data {
233 this.read_state = ReadState::Init;
234 return Poll::Ready(Some(Ok(Frame::new(header))));
235 }
236
237 let body_len = header.len().val() as usize;
238
239 if body_len > MAX_FRAME_BODY_LEN {
240 return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge(
241 body_len,
242 ))));
243 }
244
245 this.read_state = ReadState::Body {
246 header,
247 offset: 0,
248 buffer: vec![0; body_len],
249 };
250
251 continue;
252 }
253
254 let buf = &mut buffer[*offset..header::HEADER_SIZE];
255 match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
256 0 => {
257 if *offset == 0 {
258 return Poll::Ready(None);
259 }
260 let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
261 return Poll::Ready(Some(Err(e)));
262 }
263 n => *offset += n,
264 }
265 }
266 ReadState::Body {
267 ref header,
268 ref mut offset,
269 ref mut buffer,
270 } => {
271 let body_len = header.len().val() as usize;
272
273 if *offset == body_len {
274 let h = header.clone();
275 let v = std::mem::take(buffer);
276 this.read_state = ReadState::Init;
277 return Poll::Ready(Some(Ok(Frame { header: h, body: v })));
278 }
279
280 let buf = &mut buffer[*offset..body_len];
281 match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
282 0 => {
283 let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
284 return Poll::Ready(Some(Err(e)));
285 }
286 n => *offset += n,
287 }
288 }
289 }
290 }
291 }
292}
293
294impl fmt::Debug for ReadState {
295 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
296 match self {
297 ReadState::Init => f.write_str("(ReadState::Init)"),
298 ReadState::Header { offset, .. } => {
299 write!(f, "(ReadState::Header (offset {offset}))")
300 }
301 ReadState::Body {
302 header,
303 offset,
304 buffer,
305 } => {
306 write!(
307 f,
308 "(ReadState::Body (header {}) (offset {}) (buffer-len {}))",
309 header,
310 offset,
311 buffer.len()
312 )
313 }
314 }
315 }
316}
317
/// Error produced while decoding a frame from the underlying reader.
#[non_exhaustive]
#[derive(Debug)]
pub enum FrameDecodeError {
    /// An I/O error from the underlying reader, including `UnexpectedEof`
    /// when the stream ends in the middle of a frame.
    Io(io::Error),
    /// The frame header could not be decoded.
    Header(HeaderDecodeError),
    /// A `Data` frame announced a body longer than the decoder accepts;
    /// carries the announced body length.
    FrameTooLarge(usize),
}
329
330impl std::fmt::Display for FrameDecodeError {
331 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
332 match self {
333 FrameDecodeError::Io(e) => write!(f, "i/o error: {e}"),
334 FrameDecodeError::Header(e) => write!(f, "decode error: {e}"),
335 FrameDecodeError::FrameTooLarge(n) => write!(f, "frame body is too large ({n})"),
336 }
337 }
338}
339
340impl std::error::Error for FrameDecodeError {
341 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
342 match self {
343 FrameDecodeError::Io(e) => Some(e),
344 FrameDecodeError::Header(e) => Some(e),
345 FrameDecodeError::FrameTooLarge(_) => None,
346 }
347 }
348}
349
350impl From<std::io::Error> for FrameDecodeError {
351 fn from(e: std::io::Error) -> Self {
352 FrameDecodeError::Io(e)
353 }
354}
355
356impl From<HeaderDecodeError> for FrameDecodeError {
357 fn from(e: HeaderDecodeError) -> Self {
358 FrameDecodeError::Header(e)
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen, QuickCheck};
    use rand::RngCore;

    impl Arbitrary for Frame<()> {
        /// Random frame; `Data` frames get a random body of at most 4095 bytes
        /// whose length matches the header.
        fn arbitrary(g: &mut Gen) -> Self {
            let mut header: header::Header<()> = Arbitrary::arbitrary(g);
            let mut body = Vec::new();
            if header.tag() == header::Tag::Data {
                header.set_len(header.len().val() % 4096);
                body.resize(header.len().val() as usize, 0);
                rand::rng().fill_bytes(&mut body);
            }
            Frame { header, body }
        }
    }

    /// Round-trips random frames through an in-memory `Io`: each frame written
    /// must be read back unchanged.
    #[test]
    fn encode_decode_identity() {
        fn property(f: Frame<()>) -> bool {
            futures::executor::block_on(async move {
                let id = crate::connection::Id::random();
                let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()));
                let written = io.send(f.clone()).await.is_ok() && io.flush().await.is_ok();
                if !written {
                    return false;
                }
                io.io.set_position(0);
                matches!(io.try_next().await, Ok(Some(decoded)) if decoded == f)
            })
        }

        QuickCheck::new()
            .tests(10_000)
            .quickcheck(property as fn(Frame<()>) -> bool)
    }
}