1use super::{
12 header::{self, HeaderDecodeError},
13 Frame,
14};
15use crate::yamux::connection::Id;
16use futures::{prelude::*, ready};
17use std::{
18 fmt, io,
19 pin::Pin,
20 task::{Context, Poll},
21};
22
/// Logging target used by the `tracing` calls in this module.
const LOG_TARGET: &str = "litep2p::yamux";
25
/// Frame codec over an underlying byte stream.
///
/// Implements [`Sink`] for writing [`Frame`]s and [`Stream`] for reading them. Reading and
/// writing are tracked by two independent state machines. A frame that was only partially
/// read or written when the underlying I/O returned `Poll::Pending` is resumed on the next
/// poll.
#[derive(Debug)]
pub(crate) struct Io<T> {
    /// Identifier included in the trace log lines emitted by this codec.
    id: Id,
    /// The underlying I/O resource.
    io: T,
    /// Progress of the frame currently being read.
    read_state: ReadState,
    /// Progress of the frame currently being written.
    write_state: WriteState,
    /// Maximum accepted body length of an incoming data frame.
    max_body_len: usize,
}
35
36impl<T: AsyncRead + AsyncWrite + Unpin> Io<T> {
37 pub(crate) fn new(id: Id, io: T, max_frame_body_len: usize) -> Self {
38 Io {
39 id,
40 io,
41 read_state: ReadState::Init,
42 write_state: WriteState::Init,
43 max_body_len: max_frame_body_len,
44 }
45 }
46}
47
/// State of the frame currently being written by the [`Sink`] implementation of [`Io`].
enum WriteState {
    /// No frame is pending; a new frame may be started.
    Init,
    /// The encoded frame header is being written.
    Header {
        /// The encoded frame header.
        header: [u8; header::HEADER_SIZE],
        /// The frame body, written after the header only if it is non-empty.
        buffer: Vec<u8>,
        /// Number of header bytes written so far.
        offset: usize,
    },
    /// The frame body is being written.
    Body {
        /// The frame body.
        buffer: Vec<u8>,
        /// Number of body bytes written so far.
        offset: usize,
    },
}
61
62impl fmt::Debug for WriteState {
63 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
64 match self {
65 WriteState::Init => f.write_str("(WriteState::Init)"),
66 WriteState::Header { offset, .. } => {
67 write!(f, "(WriteState::Header (offset {}))", offset)
68 }
69 WriteState::Body { offset, buffer } => {
70 write!(
71 f,
72 "(WriteState::Body (offset {}) (buffer-len {}))",
73 offset,
74 buffer.len()
75 )
76 }
77 }
78 }
79}
80
/// Writes frames to the underlying I/O resource.
///
/// `start_send` only encodes the frame and stores it in the write state. The bytes are
/// actually written by `poll_ready`. `poll_flush` and `poll_close` call `poll_ready` to finish
/// any pending frame first, then delegate to the underlying resource.
impl<T: AsyncRead + AsyncWrite + Unpin> Sink<Frame<()>> for Io<T> {
    type Error = io::Error;

    /// Drives the pending frame write, if any, to completion.
    ///
    /// Returns `Poll::Ready(Ok(()))` once the write state is back to `Init`. A zero-byte write
    /// from the underlying resource is reported as `io::ErrorKind::WriteZero`.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::into_inner(self);
        loop {
            tracing::trace!(target: LOG_TARGET, "{}: write: {:?}", this.id, this.write_state);
            match &mut this.write_state {
                // Nothing pending: ready for the next frame.
                WriteState::Init => return Poll::Ready(Ok(())),
                WriteState::Header {
                    header,
                    buffer,
                    ref mut offset,
                } => match Pin::new(&mut this.io).poll_write(cx, &header[*offset..]) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                    Poll::Ready(Ok(n)) => {
                        // A zero-byte write makes no progress. Retrying would spin this loop forever.
                        if n == 0 {
                            return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
                        }
                        *offset += n;
                        if *offset == header.len() {
                            // Header fully written: continue with the body, or finish if there is none.
                            if !buffer.is_empty() {
                                // Move the body out of the old state instead of copying it.
                                let buffer = std::mem::take(buffer);
                                this.write_state = WriteState::Body { buffer, offset: 0 };
                            } else {
                                this.write_state = WriteState::Init;
                            }
                        }
                    }
                },
                WriteState::Body {
                    buffer,
                    ref mut offset,
                } => match Pin::new(&mut this.io).poll_write(cx, &buffer[*offset..]) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                    Poll::Ready(Ok(n)) => {
                        // See the header case: a zero-byte write is treated as an error.
                        if n == 0 {
                            return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
                        }
                        *offset += n;
                        if *offset == buffer.len() {
                            // Body fully written: the frame is done.
                            this.write_state = WriteState::Init;
                        }
                    }
                },
            }
        }
    }

    /// Encodes `f` and makes it the pending frame; no I/O happens here.
    ///
    /// The write state is overwritten unconditionally, so callers must follow the `Sink`
    /// contract: `poll_ready` must have returned `Ready(Ok(()))` first. Otherwise a
    /// partially-written frame would be lost.
    fn start_send(self: Pin<&mut Self>, f: Frame<()>) -> Result<(), Self::Error> {
        let header = header::encode(&f.header);
        let buffer = f.body;
        self.get_mut().write_state = WriteState::Header {
            header,
            buffer,
            offset: 0,
        };
        Ok(())
    }

    /// Finishes writing any pending frame, then flushes the underlying resource.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::into_inner(self);
        ready!(this.poll_ready_unpin(cx))?;
        Pin::new(&mut this.io).poll_flush(cx)
    }

    /// Finishes writing any pending frame, then closes the underlying resource.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::into_inner(self);
        ready!(this.poll_ready_unpin(cx))?;
        Pin::new(&mut this.io).poll_close(cx)
    }
}
155
/// State of the frame currently being read by the [`Stream`] implementation of [`Io`].
enum ReadState {
    /// No frame is in progress; the next poll starts reading a header.
    Init,
    /// The frame header is being read.
    Header {
        /// Number of header bytes read so far.
        offset: usize,
        /// Buffer receiving the raw header bytes.
        buffer: [u8; header::HEADER_SIZE],
    },
    /// The body of a data frame is being read.
    Body {
        /// The already-decoded frame header.
        header: header::Header<()>,
        /// Number of body bytes read so far.
        offset: usize,
        /// Buffer receiving the body, sized to the header's length field.
        buffer: Vec<u8>,
    },
}
172
/// Reads frames from the underlying I/O resource.
///
/// Yields `None` when the underlying resource reports end-of-file (a zero-length read) before
/// any byte of a new header has been read. End-of-file in the middle of a header or body is
/// reported as `FrameDecodeError::Io` with `io::ErrorKind::UnexpectedEof`. Errors from the
/// underlying reads are converted to `FrameDecodeError::Io` by `?`.
impl<T: AsyncRead + AsyncWrite + Unpin> Stream for Io<T> {
    type Item = Result<Frame<()>, FrameDecodeError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let this = &mut *self;
        loop {
            tracing::trace!(target: LOG_TARGET, "{}: read: {:?}", this.id, this.read_state);
            match this.read_state {
                // Start reading a fresh header.
                ReadState::Init => {
                    this.read_state = ReadState::Header {
                        offset: 0,
                        buffer: [0; header::HEADER_SIZE],
                    };
                }
                ReadState::Header {
                    ref mut offset,
                    ref mut buffer,
                } => {
                    // The whole header has been read: decode it.
                    if *offset == header::HEADER_SIZE {
                        // NOTE(review): on a decode error the read state is left unchanged, so
                        // polling again decodes the same bytes and yields the same error.
                        let header = match header::decode(buffer) {
                            Ok(hd) => hd,
                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
                        };

                        tracing::trace!(target: LOG_TARGET, "{}: read: {}", this.id, header);

                        // Frames other than data frames are returned without reading a body.
                        if header.tag() != header::Tag::Data {
                            this.read_state = ReadState::Init;
                            return Poll::Ready(Some(Ok(Frame::new(header))));
                        }

                        let body_len = header.len().val() as usize;

                        // Reject oversized bodies before allocating the body buffer.
                        // NOTE(review): the read state is not reset here either, so the same
                        // error is yielded again if the stream is polled after it.
                        if body_len > this.max_body_len {
                            return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge(
                                body_len,
                            ))));
                        }

                        this.read_state = ReadState::Body {
                            header,
                            offset: 0,
                            buffer: vec![0; body_len],
                        };

                        continue;
                    }

                    // Read the remaining header bytes.
                    let buf = &mut buffer[*offset..header::HEADER_SIZE];
                    match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
                        0 => {
                            // EOF before any header byte: clean end of the stream.
                            if *offset == 0 {
                                return Poll::Ready(None);
                            }
                            // EOF inside a header.
                            let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
                            return Poll::Ready(Some(Err(e)));
                        }
                        n => *offset += n,
                    }
                }
                ReadState::Body {
                    ref header,
                    ref mut offset,
                    ref mut buffer,
                } => {
                    let body_len = header.len().val() as usize;

                    // Body complete: emit the frame and reset for the next header.
                    if *offset == body_len {
                        let h = header.clone();
                        let v = std::mem::take(buffer);
                        this.read_state = ReadState::Init;
                        return Poll::Ready(Some(Ok(Frame { header: h, body: v })));
                    }

                    // Read the remaining body bytes.
                    let buf = &mut buffer[*offset..body_len];
                    match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
                        0 => {
                            // EOF before the body was complete.
                            let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
                            return Poll::Ready(Some(Err(e)));
                        }
                        n => *offset += n,
                    }
                }
            }
        }
    }
}
260
261impl fmt::Debug for ReadState {
262 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
263 match self {
264 ReadState::Init => f.write_str("(ReadState::Init)"),
265 ReadState::Header { offset, .. } => {
266 write!(f, "(ReadState::Header (offset {}))", offset)
267 }
268 ReadState::Body {
269 header,
270 offset,
271 buffer,
272 } => {
273 write!(
274 f,
275 "(ReadState::Body (header {}) (offset {}) (buffer-len {}))",
276 header,
277 offset,
278 buffer.len()
279 )
280 }
281 }
282 }
283}
284
/// Errors that can occur while reading and decoding a frame.
#[non_exhaustive]
#[derive(Debug)]
pub enum FrameDecodeError {
    /// An I/O error from the underlying resource, including an unexpected end-of-file in the
    /// middle of a frame.
    Io(io::Error),
    /// The frame header could not be decoded.
    Header(HeaderDecodeError),
    /// A data frame declared a body longer than the configured maximum; carries the declared
    /// body length.
    FrameTooLarge(usize),
}
296
297impl PartialEq for FrameDecodeError {
298 fn eq(&self, other: &Self) -> bool {
299 match (self, other) {
300 (FrameDecodeError::Io(e1), FrameDecodeError::Io(e2)) => e1.kind() == e2.kind(),
301 (FrameDecodeError::Header(e1), FrameDecodeError::Header(e2)) => e1 == e2,
302 (FrameDecodeError::FrameTooLarge(n1), FrameDecodeError::FrameTooLarge(n2)) => n1 == n2,
303 _ => false,
304 }
305 }
306}
307
308impl std::fmt::Display for FrameDecodeError {
309 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
310 match self {
311 FrameDecodeError::Io(e) => write!(f, "i/o error: {}", e),
312 FrameDecodeError::Header(e) => write!(f, "decode error: {}", e),
313 FrameDecodeError::FrameTooLarge(n) => write!(f, "frame body is too large ({})", n),
314 }
315 }
316}
317
318impl std::error::Error for FrameDecodeError {
319 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
320 match self {
321 FrameDecodeError::Io(e) => Some(e),
322 FrameDecodeError::Header(e) => Some(e),
323 FrameDecodeError::FrameTooLarge(_) => None,
324 }
325 }
326}
327
328impl From<std::io::Error> for FrameDecodeError {
329 fn from(e: std::io::Error) -> Self {
330 FrameDecodeError::Io(e)
331 }
332}
333
334impl From<HeaderDecodeError> for FrameDecodeError {
335 fn from(e: HeaderDecodeError) -> Self {
336 FrameDecodeError::Header(e)
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen, QuickCheck};
    use rand::RngCore;

    impl Arbitrary for Frame<()> {
        /// Generates a random frame.
        ///
        /// A data frame gets its length field reduced modulo 4096 and a random body of exactly
        /// that length. Every other frame has an empty body.
        fn arbitrary(g: &mut Gen) -> Self {
            let mut header: header::Header<()> = Arbitrary::arbitrary(g);
            let mut body = Vec::new();
            if header.tag() == header::Tag::Data {
                let body_len = header.len().val() % 4096;
                header.set_len(body_len);
                body.resize(header.len().val() as usize, 0);
                rand::thread_rng().fill_bytes(&mut body);
            }
            Frame { header, body }
        }
    }

    /// Writing a frame through `Io` and reading it back yields an identical frame.
    #[test]
    fn encode_decode_identity() {
        fn property(frame: Frame<()>) -> bool {
            futures::executor::block_on(async move {
                let conn_id = crate::yamux::connection::Id::random();
                let cursor = futures::io::Cursor::new(Vec::new());
                let mut io = Io::new(conn_id, cursor, frame.body.len());

                if io.send(frame.clone()).await.is_err() || io.flush().await.is_err() {
                    return false;
                }

                // Rewind the in-memory cursor so the bytes just written are read back.
                io.io.set_position(0);
                matches!(io.try_next().await, Ok(Some(decoded)) if decoded == frame)
            })
        }

        QuickCheck::new().tests(10_000).quickcheck(property as fn(Frame<()>) -> bool)
    }
}