litep2p/multistream_select/
length_delimited.rs1use bytes::{Buf as _, BufMut as _, Bytes, BytesMut};
22use futures::{io::IoSlice, prelude::*};
23use std::{
24 convert::TryFrom as _,
25 io,
26 pin::Pin,
27 task::{Context, Poll},
28};
29
/// Maximum number of bytes the unsigned-varint length prefix of a frame may occupy.
const MAX_LEN_BYTES: u16 = 2;

/// Largest payload length that fits in a `MAX_LEN_BYTES`-byte varint prefix.
///
/// Each varint byte contributes 7 value bits (the top bit is the continuation flag),
/// so two prefix bytes can encode lengths up to `2^14 - 1 = 16383`.
const MAX_FRAME_SIZE: u16 = (1 << (7 * MAX_LEN_BYTES)) - 1;

/// Initial capacity, in bytes, of the read buffer. The write buffer is allocated with
/// `MAX_LEN_BYTES` extra bytes so a length prefix fits alongside the payload.
const DEFAULT_BUFFER_SIZE: usize = 64;

/// `tracing` target used for this module's log output.
const LOG_TARGET: &str = "litep2p::multistream-select";
34
/// Length-delimited framing over an underlying I/O object.
///
/// As a [`Stream`] it yields frames read from `inner`. As a [`Sink`] of [`Bytes`] it
/// buffers outgoing frames and writes them to `inner`. Every frame on the wire is
/// preceded by an unsigned-varint length prefix.
#[pin_project::pin_project]
#[derive(Debug)]
pub struct LengthDelimited<R> {
    /// The underlying I/O object frames are read from and written to.
    #[pin]
    inner: R,
    /// Payload bytes of the frame currently being read.
    read_buffer: BytesMut,
    /// Encoded frames (length prefix followed by payload) not yet written to `inner`.
    write_buffer: BytesMut,
    /// Decoding progress of the frame currently being read.
    read_state: ReadState,
}
55
/// Decoding progress of the frame currently being read by [`LengthDelimited`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ReadState {
    /// Reading the unsigned-varint length prefix, one byte at a time.
    ReadLength {
        /// Prefix bytes read so far.
        buf: [u8; MAX_LEN_BYTES as usize],
        /// Number of prefix bytes already stored in `buf`.
        pos: usize,
    },
    /// Reading a payload of `len` bytes, of which `pos` have been read so far.
    ReadData { len: u16, pos: usize },
}
66
67impl Default for ReadState {
68 fn default() -> Self {
69 ReadState::ReadLength {
70 buf: [0; MAX_LEN_BYTES as usize],
71 pos: 0,
72 }
73 }
74}
75
76impl<R> LengthDelimited<R> {
77 pub fn new(inner: R) -> LengthDelimited<R> {
80 LengthDelimited {
81 inner,
82 read_state: ReadState::default(),
83 read_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE),
84 write_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE + MAX_LEN_BYTES as usize),
85 }
86 }
87
88 pub fn into_inner(self) -> R {
97 assert!(self.read_buffer.is_empty());
98 assert!(self.write_buffer.is_empty());
99 self.inner
100 }
101
102 pub fn into_reader(self) -> LengthDelimitedReader<R> {
110 LengthDelimitedReader { inner: self }
111 }
112
113 pub fn poll_write_buffer(
119 self: Pin<&mut Self>,
120 cx: &mut Context<'_>,
121 ) -> Poll<Result<(), io::Error>>
122 where
123 R: AsyncWrite,
124 {
125 let mut this = self.project();
126
127 while !this.write_buffer.is_empty() {
128 match this.inner.as_mut().poll_write(cx, this.write_buffer) {
129 Poll::Pending => return Poll::Pending,
130 Poll::Ready(Ok(0)) =>
131 return Poll::Ready(Err(io::Error::new(
132 io::ErrorKind::WriteZero,
133 "Failed to write buffered frame.",
134 ))),
135 Poll::Ready(Ok(n)) => this.write_buffer.advance(n),
136 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
137 }
138 }
139
140 Poll::Ready(Ok(()))
141 }
142}
143
144impl<R> Stream for LengthDelimited<R>
145where
146 R: AsyncRead,
147{
148 type Item = Result<Bytes, io::Error>;
149
150 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
151 let mut this = self.project();
152
153 loop {
154 match this.read_state {
155 ReadState::ReadLength { buf, pos } => {
156 match this.inner.as_mut().poll_read(cx, &mut buf[*pos..*pos + 1]) {
157 Poll::Ready(Ok(0)) =>
158 if *pos == 0 {
159 return Poll::Ready(None);
160 } else {
161 return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())));
162 },
163 Poll::Ready(Ok(n)) => {
164 debug_assert_eq!(n, 1);
165 *pos += n;
166 }
167 Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
168 Poll::Pending => return Poll::Pending,
169 };
170
171 if (buf[*pos - 1] & 0x80) == 0 {
172 let (len, _) = unsigned_varint::decode::u16(buf).map_err(|e| {
174 tracing::debug!(target: LOG_TARGET, "invalid length prefix: {}", e);
175 io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
176 })?;
177
178 if len >= 1 {
179 *this.read_state = ReadState::ReadData { len, pos: 0 };
180 this.read_buffer.resize(len as usize, 0);
181 } else {
182 debug_assert_eq!(len, 0);
183 *this.read_state = ReadState::default();
184 return Poll::Ready(Some(Ok(Bytes::new())));
185 }
186 } else if *pos == MAX_LEN_BYTES as usize {
187 return Poll::Ready(Some(Err(io::Error::new(
190 io::ErrorKind::InvalidData,
191 "Maximum frame length exceeded",
192 ))));
193 }
194 }
195 ReadState::ReadData { len, pos } => {
196 match this.inner.as_mut().poll_read(cx, &mut this.read_buffer[*pos..]) {
197 Poll::Ready(Ok(0)) =>
198 return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))),
199 Poll::Ready(Ok(n)) => *pos += n,
200 Poll::Pending => return Poll::Pending,
201 Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
202 };
203
204 if *pos == *len as usize {
205 let frame = this.read_buffer.split_off(0).freeze();
207 *this.read_state = ReadState::default();
208 return Poll::Ready(Some(Ok(frame)));
209 }
210 }
211 }
212 }
213 }
214}
215
216impl<R> Sink<Bytes> for LengthDelimited<R>
217where
218 R: AsyncWrite,
219{
220 type Error = io::Error;
221
222 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
223 if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize {
227 match self.as_mut().poll_write_buffer(cx) {
228 Poll::Ready(Ok(())) => {}
229 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
230 Poll::Pending => return Poll::Pending,
231 }
232
233 debug_assert!(self.as_mut().project().write_buffer.is_empty());
234 }
235
236 Poll::Ready(Ok(()))
237 }
238
239 fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
240 let this = self.project();
241
242 let len = match u16::try_from(item.len()) {
243 Ok(len) if len <= MAX_FRAME_SIZE => len,
244 _ =>
245 return Err(io::Error::new(
246 io::ErrorKind::InvalidData,
247 "Maximum frame size exceeded.",
248 )),
249 };
250
251 let mut uvi_buf = unsigned_varint::encode::u16_buffer();
252 let uvi_len = unsigned_varint::encode::u16(len, &mut uvi_buf);
253 this.write_buffer.reserve(len as usize + uvi_len.len());
254 this.write_buffer.put(uvi_len);
255 this.write_buffer.put(item);
256
257 Ok(())
258 }
259
260 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
261 match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
263 Poll::Ready(Ok(())) => {}
264 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
265 Poll::Pending => return Poll::Pending,
266 }
267
268 let this = self.project();
269 debug_assert!(this.write_buffer.is_empty());
270
271 this.inner.poll_flush(cx)
273 }
274
275 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
276 match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
278 Poll::Ready(Ok(())) => {}
279 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
280 Poll::Pending => return Poll::Pending,
281 }
282
283 let this = self.project();
284 debug_assert!(this.write_buffer.is_empty());
285
286 this.inner.poll_close(cx)
288 }
289}
290
/// A [`LengthDelimited`] whose write side bypasses framing.
///
/// Reads still yield length-delimited frames from the wrapped codec. Writes first
/// drain any frames the codec has buffered, then go to the underlying I/O object
/// without a length prefix.
#[pin_project::pin_project]
#[derive(Debug)]
pub struct LengthDelimitedReader<R> {
    /// The wrapped codec, which owns the underlying I/O object.
    #[pin]
    inner: LengthDelimited<R>,
}
299
300impl<R> LengthDelimitedReader<R> {
301 pub fn into_inner(self) -> R {
314 self.inner.into_inner()
315 }
316}
317
318impl<R> Stream for LengthDelimitedReader<R>
319where
320 R: AsyncRead,
321{
322 type Item = Result<Bytes, io::Error>;
323
324 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
325 self.project().inner.poll_next(cx)
326 }
327}
328
329impl<R> AsyncWrite for LengthDelimitedReader<R>
330where
331 R: AsyncWrite,
332{
333 fn poll_write(
334 self: Pin<&mut Self>,
335 cx: &mut Context<'_>,
336 buf: &[u8],
337 ) -> Poll<Result<usize, io::Error>> {
338 let mut this = self.project().inner;
340
341 match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
343 Poll::Ready(Ok(())) => {}
344 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
345 Poll::Pending => return Poll::Pending,
346 }
347 debug_assert!(this.write_buffer.is_empty());
348
349 this.project().inner.poll_write(cx, buf)
350 }
351
352 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
353 self.project().inner.poll_flush(cx)
354 }
355
356 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
357 self.project().inner.poll_close(cx)
358 }
359
360 fn poll_write_vectored(
361 self: Pin<&mut Self>,
362 cx: &mut Context<'_>,
363 bufs: &[IoSlice<'_>],
364 ) -> Poll<Result<usize, io::Error>> {
365 let mut this = self.project().inner;
367
368 match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
370 Poll::Ready(Ok(())) => {}
371 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
372 Poll::Pending => return Poll::Pending,
373 }
374 debug_assert!(this.write_buffer.is_empty());
375
376 this.project().inner.poll_write_vectored(cx, bufs)
377 }
378}