litep2p/multistream_select/
protocol.rs

1// Copyright 2017 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
//! Multistream-select protocol messages and I/O operations for
//! constructing protocol negotiation flows.
//!
//! A protocol negotiation flow is constructed by using the
//! `Stream` and `Sink` implementations of `MessageIO` and
//! `MessageReader`.
27
28use crate::{
29    codec::unsigned_varint::UnsignedVarint,
30    error::Error as Litep2pError,
31    multistream_select::{
32        length_delimited::{LengthDelimited, LengthDelimitedReader},
33        Version,
34    },
35};
36
37use bytes::{BufMut, Bytes, BytesMut};
38use futures::{io::IoSlice, prelude::*, ready};
39use std::{
40    convert::TryFrom,
41    error::Error,
42    fmt, io,
43    pin::Pin,
44    task::{Context, Poll},
45};
46use unsigned_varint as uvi;
47
/// Logging target used by this module.
const LOG_TARGET: &str = "litep2p::multistream-select";

/// Upper bound on the number of protocol names accepted in a single `ls` response.
const MAX_PROTOCOLS: usize = 1000;

/// Wire form of the multistream-select 1.0.0 header message.
pub const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n";

/// Wire form of the `na` (protocol not available) message.
const MSG_PROTOCOL_NA: &[u8] = b"na\n";

/// Wire form of the `ls` (list protocols) request message.
const MSG_LS: &[u8] = b"ls\n";
59
/// Header line that precedes multistream-select protocol negotiation.
///
/// Each [`Version`] maps onto exactly one of these header lines.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum HeaderLine {
    /// The `/multistream/1.0.0` header line.
    V1,
}
68
69impl From<Version> for HeaderLine {
70    fn from(v: Version) -> HeaderLine {
71        match v {
72            Version::V1 | Version::V1Lazy => HeaderLine::V1,
73        }
74    }
75}
76
77/// A protocol (name) exchanged during protocol negotiation.
78#[derive(Clone, Debug, PartialEq, Eq)]
79pub struct Protocol(Bytes);
80
81impl AsRef<[u8]> for Protocol {
82    fn as_ref(&self) -> &[u8] {
83        self.0.as_ref()
84    }
85}
86
87impl TryFrom<Bytes> for Protocol {
88    type Error = ProtocolError;
89
90    fn try_from(value: Bytes) -> Result<Self, Self::Error> {
91        if !value.as_ref().starts_with(b"/") {
92            return Err(ProtocolError::InvalidProtocol);
93        }
94        Ok(Protocol(value))
95    }
96}
97
98impl TryFrom<&[u8]> for Protocol {
99    type Error = ProtocolError;
100
101    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
102        Self::try_from(Bytes::copy_from_slice(value))
103    }
104}
105
106impl fmt::Display for Protocol {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        write!(f, "{}", String::from_utf8_lossy(&self.0))
109    }
110}
111
112/// A multistream-select protocol message.
113///
114/// Multistream-select protocol messages are exchanged with the goal
115/// of agreeing on a application-layer protocol to use on an I/O stream.
116#[derive(Debug, Clone, PartialEq, Eq)]
117pub enum Message {
118    /// A header message identifies the multistream-select protocol
119    /// that the sender wishes to speak.
120    Header(HeaderLine),
121    /// A protocol message identifies a protocol request or acknowledgement.
122    Protocol(Protocol),
123    /// A message through which a peer requests the complete list of
124    /// supported protocols from the remote.
125    ListProtocols,
126    /// A message listing all supported protocols of a peer.
127    Protocols(Vec<Protocol>),
128    /// A message signaling that a requested protocol is not available.
129    NotAvailable,
130}
131
132impl Message {
133    /// Encodes a `Message` into its byte representation.
134    pub fn encode(&self, dest: &mut BytesMut) -> Result<(), ProtocolError> {
135        match self {
136            Message::Header(HeaderLine::V1) => {
137                dest.reserve(MSG_MULTISTREAM_1_0.len());
138                dest.put(MSG_MULTISTREAM_1_0);
139                Ok(())
140            }
141            Message::Protocol(p) => {
142                let len = p.0.as_ref().len() + 1; // + 1 for \n
143                dest.reserve(len);
144                dest.put(p.0.as_ref());
145                dest.put_u8(b'\n');
146                Ok(())
147            }
148            Message::ListProtocols => {
149                dest.reserve(MSG_LS.len());
150                dest.put(MSG_LS);
151                Ok(())
152            }
153            Message::Protocols(ps) => {
154                let mut buf = uvi::encode::usize_buffer();
155                let mut encoded = Vec::with_capacity(ps.len());
156                for p in ps {
157                    encoded.extend(uvi::encode::usize(p.0.as_ref().len() + 1, &mut buf)); // +1 for '\n'
158                    encoded.extend_from_slice(p.0.as_ref());
159                    encoded.push(b'\n')
160                }
161                encoded.push(b'\n');
162                dest.reserve(encoded.len());
163                dest.put(encoded.as_ref());
164                Ok(())
165            }
166            Message::NotAvailable => {
167                dest.reserve(MSG_PROTOCOL_NA.len());
168                dest.put(MSG_PROTOCOL_NA);
169                Ok(())
170            }
171        }
172    }
173
174    /// Decodes a `Message` from its byte representation.
175    pub fn decode(mut msg: Bytes) -> Result<Message, ProtocolError> {
176        if msg == MSG_MULTISTREAM_1_0 {
177            return Ok(Message::Header(HeaderLine::V1));
178        }
179
180        if msg == MSG_PROTOCOL_NA {
181            return Ok(Message::NotAvailable);
182        }
183
184        if msg == MSG_LS {
185            return Ok(Message::ListProtocols);
186        }
187
188        // If it starts with a `/`, ends with a line feed without any
189        // other line feeds in-between, it must be a protocol name.
190        if msg.first() == Some(&b'/')
191            && msg.last() == Some(&b'\n')
192            && !msg[..msg.len() - 1].contains(&b'\n')
193        {
194            let p = Protocol::try_from(msg.split_to(msg.len() - 1))?;
195            return Ok(Message::Protocol(p));
196        }
197
198        // At this point, it must be an `ls` response, i.e. one or more
199        // length-prefixed, newline-delimited protocol names.
200        let mut protocols = Vec::new();
201        let mut remaining: &[u8] = &msg;
202        loop {
203            // A well-formed message must be terminated with a newline.
204            // TODO: don't do this
205            if remaining == [b'\n'] || remaining.is_empty() {
206                break;
207            } else if protocols.len() == MAX_PROTOCOLS {
208                return Err(ProtocolError::TooManyProtocols);
209            }
210
211            // Decode the length of the next protocol name and check that
212            // it ends with a line feed.
213            let (len, tail) = uvi::decode::usize(remaining)?;
214            if len == 0 || len > tail.len() || tail[len - 1] != b'\n' {
215                return Err(ProtocolError::InvalidMessage);
216            }
217
218            // Parse the protocol name.
219            let p = Protocol::try_from(Bytes::copy_from_slice(&tail[..len - 1]))?;
220            protocols.push(p);
221
222            // Skip ahead to the next protocol.
223            remaining = &tail[len..];
224        }
225
226        Ok(Message::Protocols(protocols))
227    }
228}
229
230/// Create `multistream-select` message from an iterator of `Message`s.
231pub fn encode_multistream_message(
232    messages: impl IntoIterator<Item = Message>,
233) -> crate::Result<BytesMut> {
234    // encode `/multistream-select/1.0.0` header
235    let mut bytes = BytesMut::with_capacity(32);
236    let message = Message::Header(HeaderLine::V1);
237    message.encode(&mut bytes).map_err(|_| Litep2pError::InvalidData)?;
238    let mut header = UnsignedVarint::encode(bytes)?;
239
240    // encode each message
241    for message in messages {
242        let mut proto_bytes = BytesMut::with_capacity(256);
243        message.encode(&mut proto_bytes).map_err(|_| Litep2pError::InvalidData)?;
244        let mut proto_bytes = UnsignedVarint::encode(proto_bytes)?;
245        header.append(&mut proto_bytes);
246    }
247
248    Ok(BytesMut::from(&header[..]))
249}
250
251/// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s.
252#[pin_project::pin_project]
253pub struct MessageIO<R> {
254    #[pin]
255    inner: LengthDelimited<R>,
256}
257
258impl<R> MessageIO<R> {
259    /// Constructs a new `MessageIO` resource wrapping the given I/O stream.
260    pub fn new(inner: R) -> MessageIO<R>
261    where
262        R: AsyncRead + AsyncWrite,
263    {
264        Self {
265            inner: LengthDelimited::new(inner),
266        }
267    }
268
269    /// Converts the [`MessageIO`] into a [`MessageReader`], dropping the
270    /// [`Message`]-oriented `Sink` in favour of direct `AsyncWrite` access
271    /// to the underlying I/O stream.
272    ///
273    /// This is typically done if further negotiation messages are expected to be
274    /// received but no more messages are written, allowing the writing of
275    /// follow-up protocol data to commence.
276    pub fn into_reader(self) -> MessageReader<R> {
277        MessageReader {
278            inner: self.inner.into_reader(),
279        }
280    }
281
282    /// Drops the [`MessageIO`] resource, yielding the underlying I/O stream.
283    ///
284    /// # Panics
285    ///
286    /// Panics if the read buffer or write buffer is not empty, meaning that an incoming
287    /// protocol negotiation frame has been partially read or an outgoing frame
288    /// has not yet been flushed. The read buffer is guaranteed to be empty whenever
289    /// `MessageIO::poll` returned a message. The write buffer is guaranteed to be empty
290    /// when the sink has been flushed.
291    pub fn into_inner(self) -> R {
292        self.inner.into_inner()
293    }
294}
295
296impl<R> Sink<Message> for MessageIO<R>
297where
298    R: AsyncWrite,
299{
300    type Error = ProtocolError;
301
302    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
303        self.project().inner.poll_ready(cx).map_err(From::from)
304    }
305
306    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
307        let mut buf = BytesMut::new();
308        item.encode(&mut buf)?;
309        self.project().inner.start_send(buf.freeze()).map_err(From::from)
310    }
311
312    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
313        self.project().inner.poll_flush(cx).map_err(From::from)
314    }
315
316    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
317        self.project().inner.poll_close(cx).map_err(From::from)
318    }
319}
320
321impl<R> Stream for MessageIO<R>
322where
323    R: AsyncRead,
324{
325    type Item = Result<Message, ProtocolError>;
326
327    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
328        match poll_stream(self.project().inner, cx) {
329            Poll::Pending => Poll::Pending,
330            Poll::Ready(None) => Poll::Ready(None),
331            Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(Ok(m))),
332            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
333        }
334    }
335}
336
337/// A `MessageReader` implements a `Stream` of `Message`s on an underlying
338/// I/O resource combined with direct `AsyncWrite` access.
339#[pin_project::pin_project]
340#[derive(Debug)]
341pub struct MessageReader<R> {
342    #[pin]
343    inner: LengthDelimitedReader<R>,
344}
345
346impl<R> MessageReader<R> {
347    /// Drops the `MessageReader` resource, yielding the underlying I/O stream
348    /// together with the remaining write buffer containing the protocol
349    /// negotiation frame data that has not yet been written to the I/O stream.
350    ///
351    /// # Panics
352    ///
353    /// Panics if the read buffer or write buffer is not empty, meaning that either
354    /// an incoming protocol negotiation frame has been partially read, or an
355    /// outgoing frame has not yet been flushed. The read buffer is guaranteed to
356    /// be empty whenever `MessageReader::poll` returned a message. The write
357    /// buffer is guaranteed to be empty whenever the sink has been flushed.
358    pub fn into_inner(self) -> R {
359        self.inner.into_inner()
360    }
361}
362
363impl<R> Stream for MessageReader<R>
364where
365    R: AsyncRead,
366{
367    type Item = Result<Message, ProtocolError>;
368
369    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
370        poll_stream(self.project().inner, cx)
371    }
372}
373
374impl<TInner> AsyncWrite for MessageReader<TInner>
375where
376    TInner: AsyncWrite,
377{
378    fn poll_write(
379        self: Pin<&mut Self>,
380        cx: &mut Context<'_>,
381        buf: &[u8],
382    ) -> Poll<Result<usize, io::Error>> {
383        self.project().inner.poll_write(cx, buf)
384    }
385
386    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
387        self.project().inner.poll_flush(cx)
388    }
389
390    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
391        self.project().inner.poll_close(cx)
392    }
393
394    fn poll_write_vectored(
395        self: Pin<&mut Self>,
396        cx: &mut Context<'_>,
397        bufs: &[IoSlice<'_>],
398    ) -> Poll<Result<usize, io::Error>> {
399        self.project().inner.poll_write_vectored(cx, bufs)
400    }
401}
402
403fn poll_stream<S>(
404    stream: Pin<&mut S>,
405    cx: &mut Context<'_>,
406) -> Poll<Option<Result<Message, ProtocolError>>>
407where
408    S: Stream<Item = Result<Bytes, io::Error>>,
409{
410    let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) {
411        match Message::decode(msg) {
412            Ok(m) => m,
413            Err(err) => return Poll::Ready(Some(Err(err))),
414        }
415    } else {
416        return Poll::Ready(None);
417    };
418
419    tracing::trace!(target: LOG_TARGET, "Received message: {:?}", msg);
420
421    Poll::Ready(Some(Ok(msg)))
422}
423
424/// A protocol error.
425#[derive(Debug, thiserror::Error)]
426pub enum ProtocolError {
427    /// I/O error.
428    #[error("I/O error: `{0}`")]
429    IoError(#[from] io::Error),
430
431    /// Received an invalid message from the remote.
432    #[error("Received an invalid message from the remote.")]
433    InvalidMessage,
434
435    /// A protocol (name) is invalid.
436    #[error("A protocol (name) is invalid.")]
437    InvalidProtocol,
438
439    /// Too many protocols have been returned by the remote.
440    #[error("Too many protocols have been returned by the remote.")]
441    TooManyProtocols,
442
443    /// The protocol is not supported.
444    #[error("The protocol is not supported.")]
445    ProtocolNotSupported,
446}
447
448impl PartialEq for ProtocolError {
449    fn eq(&self, other: &Self) -> bool {
450        match (self, other) {
451            (ProtocolError::IoError(lhs), ProtocolError::IoError(rhs)) => lhs.kind() == rhs.kind(),
452            _ => std::mem::discriminant(self) == std::mem::discriminant(other),
453        }
454    }
455}
456
457impl From<ProtocolError> for io::Error {
458    fn from(err: ProtocolError) -> Self {
459        if let ProtocolError::IoError(e) = err {
460            return e;
461        }
462        io::ErrorKind::InvalidData.into()
463    }
464}
465
466impl From<uvi::decode::Error> for ProtocolError {
467    fn from(err: uvi::decode::Error) -> ProtocolError {
468        Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
469    }
470}