// litep2p/multistream_select/length_delimited.rs

// Copyright 2017 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use bytes::{Buf as _, BufMut as _, Bytes, BytesMut};
use futures::{io::IoSlice, prelude::*};
use std::{
    convert::TryFrom as _,
    io,
    pin::Pin,
    task::{Context, Poll},
};

/// Maximum number of unsigned-varint bytes accepted for a frame length prefix.
const MAX_LEN_BYTES: u16 = 2;
/// Largest frame payload, in bytes, whose length fits in `MAX_LEN_BYTES`
/// unsigned-varint bytes. Each varint byte carries 7 value bits, so this is
/// `2^14 - 1 = 16383`.
const MAX_FRAME_SIZE: u16 = (1 << (7 * MAX_LEN_BYTES)) - 1;
/// Base initial capacity, in bytes, used when allocating the read and write buffers.
const DEFAULT_BUFFER_SIZE: usize = 64;
/// `tracing` target used for log output from this module.
const LOG_TARGET: &str = "litep2p::multistream-select";

35/// A `Stream` and `Sink` for unsigned-varint length-delimited frames,
36/// wrapping an underlying `AsyncRead + AsyncWrite` I/O resource.
37///
38/// We purposely only support a frame sizes up to 16KiB (2 bytes unsigned varint
39/// frame length). Frames mostly consist in a short protocol name, which is highly
40/// unlikely to be more than 16KiB long.
41#[pin_project::pin_project]
42#[derive(Debug)]
43pub struct LengthDelimited<R> {
44    /// The inner I/O resource.
45    #[pin]
46    inner: R,
47    /// Read buffer for a single incoming unsigned-varint length-delimited frame.
48    read_buffer: BytesMut,
49    /// Write buffer for outgoing unsigned-varint length-delimited frames.
50    write_buffer: BytesMut,
51    /// The current read state, alternating between reading a frame
52    /// length and reading a frame payload.
53    read_state: ReadState,
54}
55
56#[derive(Debug, Copy, Clone, PartialEq, Eq)]
57enum ReadState {
58    /// We are currently reading the length of the next frame of data.
59    ReadLength {
60        buf: [u8; MAX_LEN_BYTES as usize],
61        pos: usize,
62    },
63    /// We are currently reading the frame of data itself.
64    ReadData { len: u16, pos: usize },
65}
66
67impl Default for ReadState {
68    fn default() -> Self {
69        ReadState::ReadLength {
70            buf: [0; MAX_LEN_BYTES as usize],
71            pos: 0,
72        }
73    }
74}
75
76impl<R> LengthDelimited<R> {
77    /// Creates a new I/O resource for reading and writing unsigned-varint
78    /// length delimited frames.
79    pub fn new(inner: R) -> LengthDelimited<R> {
80        LengthDelimited {
81            inner,
82            read_state: ReadState::default(),
83            read_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE),
84            write_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE + MAX_LEN_BYTES as usize),
85        }
86    }
87
88    /// Drops the [`LengthDelimited`] resource, yielding the underlying I/O stream.
89    ///
90    /// # Panic
91    ///
92    /// Will panic if called while there is data in the read or write buffer.
93    /// The read buffer is guaranteed to be empty whenever `Stream::poll` yields
94    /// a new `Bytes` frame. The write buffer is guaranteed to be empty after
95    /// flushing.
96    pub fn into_inner(self) -> R {
97        assert!(self.read_buffer.is_empty());
98        assert!(self.write_buffer.is_empty());
99        self.inner
100    }
101
102    /// Converts the [`LengthDelimited`] into a [`LengthDelimitedReader`], dropping the
103    /// uvi-framed `Sink` in favour of direct `AsyncWrite` access to the underlying
104    /// I/O stream.
105    ///
106    /// This is typically done if further uvi-framed messages are expected to be
107    /// received but no more such messages are written, allowing the writing of
108    /// follow-up protocol data to commence.
109    pub fn into_reader(self) -> LengthDelimitedReader<R> {
110        LengthDelimitedReader { inner: self }
111    }
112
113    /// Writes all buffered frame data to the underlying I/O stream,
114    /// _without flushing it_.
115    ///
116    /// After this method returns `Poll::Ready`, the write buffer of frames
117    /// submitted to the `Sink` is guaranteed to be empty.
118    pub fn poll_write_buffer(
119        self: Pin<&mut Self>,
120        cx: &mut Context<'_>,
121    ) -> Poll<Result<(), io::Error>>
122    where
123        R: AsyncWrite,
124    {
125        let mut this = self.project();
126
127        while !this.write_buffer.is_empty() {
128            match this.inner.as_mut().poll_write(cx, this.write_buffer) {
129                Poll::Pending => return Poll::Pending,
130                Poll::Ready(Ok(0)) =>
131                    return Poll::Ready(Err(io::Error::new(
132                        io::ErrorKind::WriteZero,
133                        "Failed to write buffered frame.",
134                    ))),
135                Poll::Ready(Ok(n)) => this.write_buffer.advance(n),
136                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
137            }
138        }
139
140        Poll::Ready(Ok(()))
141    }
142}
143
144impl<R> Stream for LengthDelimited<R>
145where
146    R: AsyncRead,
147{
148    type Item = Result<Bytes, io::Error>;
149
150    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
151        let mut this = self.project();
152
153        loop {
154            match this.read_state {
155                ReadState::ReadLength { buf, pos } => {
156                    match this.inner.as_mut().poll_read(cx, &mut buf[*pos..*pos + 1]) {
157                        Poll::Ready(Ok(0)) =>
158                            if *pos == 0 {
159                                return Poll::Ready(None);
160                            } else {
161                                return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())));
162                            },
163                        Poll::Ready(Ok(n)) => {
164                            debug_assert_eq!(n, 1);
165                            *pos += n;
166                        }
167                        Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
168                        Poll::Pending => return Poll::Pending,
169                    };
170
171                    if (buf[*pos - 1] & 0x80) == 0 {
172                        // MSB is not set, indicating the end of the length prefix.
173                        let (len, _) = unsigned_varint::decode::u16(buf).map_err(|e| {
174                            tracing::debug!(target: LOG_TARGET, "invalid length prefix: {}", e);
175                            io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
176                        })?;
177
178                        if len >= 1 {
179                            *this.read_state = ReadState::ReadData { len, pos: 0 };
180                            this.read_buffer.resize(len as usize, 0);
181                        } else {
182                            debug_assert_eq!(len, 0);
183                            *this.read_state = ReadState::default();
184                            return Poll::Ready(Some(Ok(Bytes::new())));
185                        }
186                    } else if *pos == MAX_LEN_BYTES as usize {
187                        // MSB signals more length bytes but we have already read the maximum.
188                        // See the module documentation about the max frame len.
189                        return Poll::Ready(Some(Err(io::Error::new(
190                            io::ErrorKind::InvalidData,
191                            "Maximum frame length exceeded",
192                        ))));
193                    }
194                }
195                ReadState::ReadData { len, pos } => {
196                    match this.inner.as_mut().poll_read(cx, &mut this.read_buffer[*pos..]) {
197                        Poll::Ready(Ok(0)) =>
198                            return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))),
199                        Poll::Ready(Ok(n)) => *pos += n,
200                        Poll::Pending => return Poll::Pending,
201                        Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
202                    };
203
204                    if *pos == *len as usize {
205                        // Finished reading the frame.
206                        let frame = this.read_buffer.split_off(0).freeze();
207                        *this.read_state = ReadState::default();
208                        return Poll::Ready(Some(Ok(frame)));
209                    }
210                }
211            }
212        }
213    }
214}
215
216impl<R> Sink<Bytes> for LengthDelimited<R>
217where
218    R: AsyncWrite,
219{
220    type Error = io::Error;
221
222    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
223        // Use the maximum frame length also as a (soft) upper limit
224        // for the entire write buffer. The actual (hard) limit is thus
225        // implied to be roughly 2 * MAX_FRAME_SIZE.
226        if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize {
227            match self.as_mut().poll_write_buffer(cx) {
228                Poll::Ready(Ok(())) => {}
229                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
230                Poll::Pending => return Poll::Pending,
231            }
232
233            debug_assert!(self.as_mut().project().write_buffer.is_empty());
234        }
235
236        Poll::Ready(Ok(()))
237    }
238
239    fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
240        let this = self.project();
241
242        let len = match u16::try_from(item.len()) {
243            Ok(len) if len <= MAX_FRAME_SIZE => len,
244            _ =>
245                return Err(io::Error::new(
246                    io::ErrorKind::InvalidData,
247                    "Maximum frame size exceeded.",
248                )),
249        };
250
251        let mut uvi_buf = unsigned_varint::encode::u16_buffer();
252        let uvi_len = unsigned_varint::encode::u16(len, &mut uvi_buf);
253        this.write_buffer.reserve(len as usize + uvi_len.len());
254        this.write_buffer.put(uvi_len);
255        this.write_buffer.put(item);
256
257        Ok(())
258    }
259
260    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
261        // Write all buffered frame data to the underlying I/O stream.
262        match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
263            Poll::Ready(Ok(())) => {}
264            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
265            Poll::Pending => return Poll::Pending,
266        }
267
268        let this = self.project();
269        debug_assert!(this.write_buffer.is_empty());
270
271        // Flush the underlying I/O stream.
272        this.inner.poll_flush(cx)
273    }
274
275    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
276        // Write all buffered frame data to the underlying I/O stream.
277        match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
278            Poll::Ready(Ok(())) => {}
279            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
280            Poll::Pending => return Poll::Pending,
281        }
282
283        let this = self.project();
284        debug_assert!(this.write_buffer.is_empty());
285
286        // Close the underlying I/O stream.
287        this.inner.poll_close(cx)
288    }
289}
290
291/// A `LengthDelimitedReader` implements a `Stream` of uvi-length-delimited
292/// frames on an underlying I/O resource combined with direct `AsyncWrite` access.
293#[pin_project::pin_project]
294#[derive(Debug)]
295pub struct LengthDelimitedReader<R> {
296    #[pin]
297    inner: LengthDelimited<R>,
298}
299
300impl<R> LengthDelimitedReader<R> {
301    /// Destroys the `LengthDelimitedReader` and returns the underlying I/O stream.
302    ///
303    /// This method is guaranteed not to drop any data read from or not yet
304    /// submitted to the underlying I/O stream.
305    ///
306    /// # Panic
307    ///
308    /// Will panic if called while there is data in the read or write buffer.
309    /// The read buffer is guaranteed to be empty whenever [`Stream::poll_next`]
310    /// yield a new `Message`. The write buffer is guaranteed to be empty whenever
311    /// [`LengthDelimited::poll_write_buffer`] yields [`Poll::Ready`] or after
312    /// the [`Sink`] has been completely flushed via [`Sink::poll_flush`].
313    pub fn into_inner(self) -> R {
314        self.inner.into_inner()
315    }
316}
317
318impl<R> Stream for LengthDelimitedReader<R>
319where
320    R: AsyncRead,
321{
322    type Item = Result<Bytes, io::Error>;
323
324    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
325        self.project().inner.poll_next(cx)
326    }
327}
328
329impl<R> AsyncWrite for LengthDelimitedReader<R>
330where
331    R: AsyncWrite,
332{
333    fn poll_write(
334        self: Pin<&mut Self>,
335        cx: &mut Context<'_>,
336        buf: &[u8],
337    ) -> Poll<Result<usize, io::Error>> {
338        // `this` here designates the `LengthDelimited`.
339        let mut this = self.project().inner;
340
341        // We need to flush any data previously written with the `LengthDelimited`.
342        match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
343            Poll::Ready(Ok(())) => {}
344            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
345            Poll::Pending => return Poll::Pending,
346        }
347        debug_assert!(this.write_buffer.is_empty());
348
349        this.project().inner.poll_write(cx, buf)
350    }
351
352    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
353        self.project().inner.poll_flush(cx)
354    }
355
356    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
357        self.project().inner.poll_close(cx)
358    }
359
360    fn poll_write_vectored(
361        self: Pin<&mut Self>,
362        cx: &mut Context<'_>,
363        bufs: &[IoSlice<'_>],
364    ) -> Poll<Result<usize, io::Error>> {
365        // `this` here designates the `LengthDelimited`.
366        let mut this = self.project().inner;
367
368        // We need to flush any data previously written with the `LengthDelimited`.
369        match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
370            Poll::Ready(Ok(())) => {}
371            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
372            Poll::Pending => return Poll::Pending,
373        }
374        debug_assert!(this.write_buffer.is_empty());
375
376        this.project().inner.poll_write_vectored(cx, bufs)
377    }
378}