litep2p/multistream_select/
listener_select.rs

// Copyright 2017 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
20
//! Protocol negotiation strategies for the peer acting as the listener
//! in a multistream-select protocol negotiation.
23
24use crate::{
25    codec::unsigned_varint::UnsignedVarint,
26    error::{self, Error},
27    multistream_select::{
28        protocol::{
29            encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError,
30        },
31        Negotiated, NegotiationError,
32    },
33    types::protocol::ProtocolName,
34};
35
36use bytes::{Bytes, BytesMut};
37use futures::prelude::*;
38use smallvec::SmallVec;
39use std::{
40    convert::TryFrom as _,
41    iter::FromIterator,
42    mem,
43    pin::Pin,
44    task::{Context, Poll},
45};
46
/// Logging target for `tracing` events emitted by this module.
const LOG_TARGET: &str = "litep2p::multistream-select";
48
49/// Returns a `Future` that negotiates a protocol on the given I/O stream
50/// for a peer acting as the _listener_ (or _responder_).
51///
52/// This function is given an I/O stream and a list of protocols and returns a
53/// computation that performs the protocol negotiation with the remote. The
54/// returned `Future` resolves with the name of the negotiated protocol and
55/// a [`Negotiated`] I/O stream.
56pub fn listener_select_proto<R, I>(inner: R, protocols: I) -> ListenerSelectFuture<R, I::Item>
57where
58    R: AsyncRead + AsyncWrite,
59    I: IntoIterator,
60    I::Item: AsRef<[u8]>,
61{
62    let protocols = protocols.into_iter().filter_map(|n| match Protocol::try_from(n.as_ref()) {
63        Ok(p) => Some((n, p)),
64        Err(e) => {
65            tracing::warn!(
66                target: LOG_TARGET,
67                "Listener: Ignoring invalid protocol: {} due to {}",
68                String::from_utf8_lossy(n.as_ref()),
69                e
70            );
71            None
72        }
73    });
74    ListenerSelectFuture {
75        protocols: SmallVec::from_iter(protocols),
76        state: State::RecvHeader {
77            io: MessageIO::new(inner),
78        },
79        last_sent_na: false,
80    }
81}
82
/// The `Future` returned by [`listener_select_proto`] that performs a
/// multistream-select protocol negotiation on an underlying I/O stream.
///
/// On success it resolves with the caller-supplied name of the negotiated
/// protocol together with a [`Negotiated`] I/O stream.
#[pin_project::pin_project]
pub struct ListenerSelectFuture<R, N> {
    // TODO: It would be nice if eventually N = Protocol, which has a
    // few more implications on the API.
    /// Locally supported protocols: each caller-supplied name paired with its
    /// parsed [`Protocol`]. When built by [`listener_select_proto`], names that
    /// failed to parse are not included.
    protocols: SmallVec<[(N, Protocol); 8]>,
    /// Current state of the negotiation state machine.
    state: State<R, N>,
    /// Whether the last message sent was a protocol rejection (i.e. `na\n`).
    ///
    /// If the listener reads garbage or EOF after such a rejection,
    /// the dialer is likely using `V1Lazy` and negotiation must be
    /// considered failed, but not with a protocol violation or I/O
    /// error.
    last_sent_na: bool,
}
99
/// States of the listener-side negotiation state machine driven by
/// `ListenerSelectFuture::poll`.
enum State<R, N> {
    /// Waiting for the remote's multistream-select header line.
    RecvHeader {
        io: MessageIO<R>,
    },
    /// Sending our own header line after receiving the remote's.
    SendHeader {
        io: MessageIO<R>,
    },
    /// Waiting for the next request from the remote (a protocol proposal or a
    /// request to list the supported protocols).
    RecvMessage {
        io: MessageIO<R>,
    },
    /// Sending `message` to the remote. `protocol` is `Some` when `message`
    /// confirms a locally supported protocol, in which case negotiation
    /// completes once the message has been flushed.
    SendMessage {
        io: MessageIO<R>,
        message: Message,
        protocol: Option<N>,
    },
    /// Flushing previously written data. Afterwards the negotiation either
    /// completes with `protocol` (if `Some`) or returns to `RecvMessage`.
    Flush {
        io: MessageIO<R>,
        protocol: Option<N>,
    },
    /// Placeholder stored while a state is being processed, and the terminal
    /// state once the future has resolved. Polling in this state panics.
    Done,
}
121
122impl<R, N> Future for ListenerSelectFuture<R, N>
123where
124    // The Unpin bound here is required because we
125    // produce a `Negotiated<R>` as the output.
126    // It also makes the implementation considerably
127    // easier to write.
128    R: AsyncRead + AsyncWrite + Unpin,
129    N: AsRef<[u8]> + Clone,
130{
131    type Output = Result<(N, Negotiated<R>), NegotiationError>;
132
133    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
134        let this = self.project();
135
136        loop {
137            match mem::replace(this.state, State::Done) {
138                State::RecvHeader { mut io } => {
139                    match io.poll_next_unpin(cx) {
140                        Poll::Ready(Some(Ok(Message::Header(h)))) => match h {
141                            HeaderLine::V1 => *this.state = State::SendHeader { io },
142                        },
143                        Poll::Ready(Some(Ok(_))) =>
144                            return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
145                        Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))),
146                        // Treat EOF error as [`NegotiationError::Failed`], not as
147                        // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
148                        // stream as a permissible way to "gracefully" fail a negotiation.
149                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
150                        Poll::Pending => {
151                            *this.state = State::RecvHeader { io };
152                            return Poll::Pending;
153                        }
154                    }
155                }
156
157                State::SendHeader { mut io } => {
158                    match Pin::new(&mut io).poll_ready(cx) {
159                        Poll::Pending => {
160                            *this.state = State::SendHeader { io };
161                            return Poll::Pending;
162                        }
163                        Poll::Ready(Ok(())) => {}
164                        Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
165                    }
166
167                    let msg = Message::Header(HeaderLine::V1);
168                    if let Err(err) = Pin::new(&mut io).start_send(msg) {
169                        return Poll::Ready(Err(From::from(err)));
170                    }
171
172                    *this.state = State::Flush { io, protocol: None };
173                }
174
175                State::RecvMessage { mut io } => {
176                    let msg = match Pin::new(&mut io).poll_next(cx) {
177                        Poll::Ready(Some(Ok(msg))) => msg,
178                        // Treat EOF error as [`NegotiationError::Failed`], not as
179                        // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
180                        // stream as a permissible way to "gracefully" fail a negotiation.
181                        //
182                        // This is e.g. important when a listener rejects a protocol with
183                        // [`Message::NotAvailable`] and the dialer does not have alternative
184                        // protocols to propose. Then the dialer will stop the negotiation and drop
185                        // the corresponding stream. As a listener this EOF should be interpreted as
186                        // a failed negotiation.
187                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
188                        Poll::Pending => {
189                            *this.state = State::RecvMessage { io };
190                            return Poll::Pending;
191                        }
192                        Poll::Ready(Some(Err(err))) => {
193                            if *this.last_sent_na {
194                                // When we read garbage or EOF after having already rejected a
195                                // protocol, the dialer is most likely using `V1Lazy` and has
196                                // optimistically settled on this protocol, so this is really a
197                                // failed negotiation, not a protocol violation. In this case
198                                // the dialer also raises `NegotiationError::Failed` when finally
199                                // reading the `N/A` response.
200                                if let ProtocolError::InvalidMessage = &err {
201                                    tracing::trace!(
202                                        target: LOG_TARGET,
203                                        "Listener: Negotiation failed with invalid \
204                                        message after protocol rejection."
205                                    );
206                                    return Poll::Ready(Err(NegotiationError::Failed));
207                                }
208                                if let ProtocolError::IoError(e) = &err {
209                                    if e.kind() == std::io::ErrorKind::UnexpectedEof {
210                                        tracing::trace!(
211                                            target: LOG_TARGET,
212                                            "Listener: Negotiation failed with EOF \
213                                            after protocol rejection."
214                                        );
215                                        return Poll::Ready(Err(NegotiationError::Failed));
216                                    }
217                                }
218                            }
219
220                            return Poll::Ready(Err(From::from(err)));
221                        }
222                    };
223
224                    match msg {
225                        Message::ListProtocols => {
226                            let supported =
227                                this.protocols.iter().map(|(_, p)| p).cloned().collect();
228                            let message = Message::Protocols(supported);
229                            *this.state = State::SendMessage {
230                                io,
231                                message,
232                                protocol: None,
233                            }
234                        }
235                        Message::Protocol(p) => {
236                            let protocol = this.protocols.iter().find_map(|(name, proto)| {
237                                if &p == proto {
238                                    Some(name.clone())
239                                } else {
240                                    None
241                                }
242                            });
243
244                            let message = if protocol.is_some() {
245                                tracing::debug!("Listener: confirming protocol: {}", p);
246                                Message::Protocol(p.clone())
247                            } else {
248                                tracing::debug!(
249                                    "Listener: rejecting protocol: {}",
250                                    String::from_utf8_lossy(p.as_ref())
251                                );
252                                Message::NotAvailable
253                            };
254
255                            *this.state = State::SendMessage {
256                                io,
257                                message,
258                                protocol,
259                            };
260                        }
261                        _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
262                    }
263                }
264
265                State::SendMessage {
266                    mut io,
267                    message,
268                    protocol,
269                } => {
270                    match Pin::new(&mut io).poll_ready(cx) {
271                        Poll::Pending => {
272                            *this.state = State::SendMessage {
273                                io,
274                                message,
275                                protocol,
276                            };
277                            return Poll::Pending;
278                        }
279                        Poll::Ready(Ok(())) => {}
280                        Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
281                    }
282
283                    if let Message::NotAvailable = &message {
284                        *this.last_sent_na = true;
285                    } else {
286                        *this.last_sent_na = false;
287                    }
288
289                    if let Err(err) = Pin::new(&mut io).start_send(message) {
290                        return Poll::Ready(Err(From::from(err)));
291                    }
292
293                    *this.state = State::Flush { io, protocol };
294                }
295
296                State::Flush { mut io, protocol } => {
297                    match Pin::new(&mut io).poll_flush(cx) {
298                        Poll::Pending => {
299                            *this.state = State::Flush { io, protocol };
300                            return Poll::Pending;
301                        }
302                        Poll::Ready(Ok(())) => {
303                            // If a protocol has been selected, finish negotiation.
304                            // Otherwise expect to receive another message.
305                            match protocol {
306                                Some(protocol) => {
307                                    tracing::debug!(
308                                        "Listener: sent confirmed protocol: {}",
309                                        String::from_utf8_lossy(protocol.as_ref())
310                                    );
311                                    let io = Negotiated::completed(io.into_inner());
312                                    return Poll::Ready(Ok((protocol, io)));
313                                }
314                                None => *this.state = State::RecvMessage { io },
315                            }
316                        }
317                        Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
318                    }
319                }
320
321                State::Done => panic!("State::poll called after completion"),
322            }
323        }
324    }
325}
326
/// Result of [`listener_negotiate()`].
///
/// Both variants carry the encoded `multistream-select` response that should be
/// sent back to the dialer.
#[derive(Debug)]
pub enum ListenerSelectResult {
    /// Requested protocol is available and substream can be accepted.
    Accepted {
        /// Locally supported protocol that matched one of the offered protocols.
        protocol: ProtocolName,

        /// Encoded `multistream-select` message confirming the matched protocol.
        message: BytesMut,
    },

    /// Requested protocol is not available.
    Rejected {
        /// Encoded `multistream-select` "not available" (`na`) message.
        message: BytesMut,
    },
}
345
346/// Negotiate protocols for listener.
347///
348/// Parse protocols offered by the remote peer and check if any of the offered protocols match
349/// locally available protocols. If a match is found, return an encoded multistream-select
350/// response and the negotiated protocol. If parsing fails or no match is found, return an error.
351pub fn listener_negotiate<'a>(
352    supported_protocols: &'a mut impl Iterator<Item = &'a ProtocolName>,
353    payload: Bytes,
354) -> crate::Result<ListenerSelectResult> {
355    let Message::Protocols(protocols) = Message::decode(payload).map_err(|_| Error::InvalidData)?
356    else {
357        return Err(Error::NegotiationError(
358            error::NegotiationError::MultistreamSelectError(NegotiationError::Failed),
359        ));
360    };
361
362    // skip the multistream-select header because it's not part of user protocols but verify it's
363    // present
364    let mut protocol_iter = protocols.into_iter();
365    let header =
366        Protocol::try_from(&b"/multistream/1.0.0"[..]).expect("valid multitstream-select header");
367
368    if protocol_iter.next() != Some(header) {
369        return Err(Error::NegotiationError(
370            error::NegotiationError::MultistreamSelectError(NegotiationError::Failed),
371        ));
372    }
373
374    for protocol in protocol_iter {
375        tracing::trace!(
376            target: LOG_TARGET,
377            protocol = ?std::str::from_utf8(protocol.as_ref()),
378            "listener: checking protocol",
379        );
380
381        for supported in &mut *supported_protocols {
382            if protocol.as_ref() == supported.as_bytes() {
383                return Ok(ListenerSelectResult::Accepted {
384                    protocol: supported.clone(),
385                    message: encode_multistream_message(std::iter::once(Message::Protocol(
386                        protocol,
387                    )))?,
388                });
389            }
390        }
391    }
392
393    tracing::trace!(
394        target: LOG_TARGET,
395        "listener: handshake rejected, no supported protocol found",
396    );
397
398    Ok(ListenerSelectResult::Rejected {
399        message: encode_multistream_message(std::iter::once(Message::NotAvailable))?,
400    })
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Protocols supported locally by the listener in every test below.
    fn supported_protocols() -> Vec<ProtocolName> {
        vec![
            ProtocolName::from("/13371338/proto/1"),
            ProtocolName::from("/sup/proto/1"),
            ProtocolName::from("/13371338/proto/2"),
            ProtocolName::from("/13371338/proto/3"),
            ProtocolName::from("/13371338/proto/4"),
        ]
    }

    #[test]
    fn listener_negotiate_works() {
        let local_protocols = supported_protocols();
        let message = encode_multistream_message(
            vec![
                Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()),
                Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()),
            ]
            .into_iter(),
        )
        .unwrap()
        .freeze();

        match listener_negotiate(&mut local_protocols.iter(), message) {
            Err(error) => panic!("error received: {error:?}"),
            Ok(ListenerSelectResult::Rejected { .. }) => panic!("message rejected"),
            Ok(ListenerSelectResult::Accepted { protocol, message }) => {
                assert_eq!(protocol, ProtocolName::from("/13371338/proto/1"));
                // The response confirms the accepted protocol back to the dialer.
                assert_eq!(
                    message,
                    encode_multistream_message(std::iter::once(Message::Protocol(
                        Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()
                    )))
                    .unwrap()
                );
            }
        }
    }

    #[test]
    fn invalid_message() {
        let local_protocols = supported_protocols();
        let message = encode_multistream_message(std::iter::once(Message::Protocols(vec![
            Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
            Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
        ])))
        .unwrap()
        .freeze();

        match listener_negotiate(&mut local_protocols.iter(), message) {
            Err(error) => assert!(std::matches!(error, Error::InvalidData)),
            _ => panic!("invalid event"),
        }
    }

    #[test]
    fn only_header_line_received() {
        let local_protocols = supported_protocols();

        // send only header line
        let mut bytes = BytesMut::with_capacity(32);
        let message = Message::Header(HeaderLine::V1);
        let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

        match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
            Err(error) => assert!(std::matches!(
                error,
                Error::NegotiationError(error::NegotiationError::MultistreamSelectError(
                    NegotiationError::Failed
                ))
            )),
            event => panic!("invalid event: {event:?}"),
        }
    }

    #[test]
    fn header_line_missing() {
        let local_protocols = supported_protocols();

        // header line missing
        let mut bytes = BytesMut::with_capacity(256);
        let message = Message::Protocols(vec![
            Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
            Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
        ]);
        let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

        match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
            Err(error) => assert!(std::matches!(
                error,
                Error::NegotiationError(error::NegotiationError::MultistreamSelectError(
                    NegotiationError::Failed
                ))
            )),
            event => panic!("invalid event: {event:?}"),
        }
    }

    #[test]
    fn protocol_not_supported() {
        let local_protocols = supported_protocols();
        let message = encode_multistream_message(
            vec![Message::Protocol(
                Protocol::try_from(&b"/13371339/proto/1"[..]).unwrap(),
            )]
            .into_iter(),
        )
        .unwrap()
        .freeze();

        match listener_negotiate(&mut local_protocols.iter(), message) {
            Err(error) => panic!("error received: {error:?}"),
            Ok(ListenerSelectResult::Rejected { message }) => {
                assert_eq!(
                    message,
                    encode_multistream_message(std::iter::once(Message::NotAvailable)).unwrap()
                );
            }
            Ok(ListenerSelectResult::Accepted { .. }) => panic!("message accepted"),
        }
    }
}