litep2p/multistream_select/
dialer_select.rs

1// Copyright 2017 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! Protocol negotiation strategies for the peer acting as the dialer.
22
23use crate::{
24    codec::unsigned_varint::UnsignedVarint,
25    error::{self, Error, ParseError},
26    multistream_select::{
27        protocol::{
28            encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError,
29        },
30        Negotiated, NegotiationError, Version,
31    },
32    types::protocol::ProtocolName,
33};
34
35use bytes::BytesMut;
36use futures::prelude::*;
37use rustls::internal::msgs::hsjoiner::HandshakeJoiner;
38use std::{
39    convert::TryFrom as _,
40    iter, mem,
41    pin::Pin,
42    task::{Context, Poll},
43};
44
45const LOG_TARGET: &str = "litep2p::multistream-select";
46
47/// Returns a `Future` that negotiates a protocol on the given I/O stream
48/// for a peer acting as the _dialer_ (or _initiator_).
49///
50/// This function is given an I/O stream and a list of protocols and returns a
51/// computation that performs the protocol negotiation with the remote. The
52/// returned `Future` resolves with the name of the negotiated protocol and
53/// a [`Negotiated`] I/O stream.
54///
55/// Within the scope of this library, a dialer always commits to a specific
56/// multistream-select [`Version`], whereas a listener always supports
57/// all versions supported by this library. Frictionless multistream-select
58/// protocol upgrades may thus proceed by deployments with updated listeners,
59/// eventually followed by deployments of dialers choosing the newer protocol.
60pub fn dialer_select_proto<R, I>(
61    inner: R,
62    protocols: I,
63    version: Version,
64) -> DialerSelectFuture<R, I::IntoIter>
65where
66    R: AsyncRead + AsyncWrite,
67    I: IntoIterator,
68    I::Item: AsRef<[u8]>,
69{
70    let protocols = protocols.into_iter().peekable();
71    DialerSelectFuture {
72        version,
73        protocols,
74        state: State::SendHeader {
75            io: MessageIO::new(inner),
76        },
77    }
78}
79
80/// A `Future` returned by [`dialer_select_proto`] which negotiates
81/// a protocol iteratively by considering one protocol after the other.
82#[pin_project::pin_project]
83pub struct DialerSelectFuture<R, I: Iterator> {
84    // TODO: It would be nice if eventually N = I::Item = Protocol.
85    protocols: iter::Peekable<I>,
86    state: State<R, I::Item>,
87    version: Version,
88}
89
90enum State<R, N> {
91    SendHeader {
92        io: MessageIO<R>,
93    },
94    SendProtocol {
95        io: MessageIO<R>,
96        protocol: N,
97        header_received: bool,
98    },
99    FlushProtocol {
100        io: MessageIO<R>,
101        protocol: N,
102        header_received: bool,
103    },
104    AwaitProtocol {
105        io: MessageIO<R>,
106        protocol: N,
107        header_received: bool,
108    },
109    Done,
110}
111
112impl<R, I> Future for DialerSelectFuture<R, I>
113where
114    // The Unpin bound here is required because we produce
115    // a `Negotiated<R>` as the output. It also makes
116    // the implementation considerably easier to write.
117    R: AsyncRead + AsyncWrite + Unpin,
118    I: Iterator,
119    I::Item: AsRef<[u8]>,
120{
121    type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
122
123    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
124        let this = self.project();
125
126        loop {
127            match mem::replace(this.state, State::Done) {
128                State::SendHeader { mut io } => {
129                    match Pin::new(&mut io).poll_ready(cx)? {
130                        Poll::Ready(()) => {}
131                        Poll::Pending => {
132                            *this.state = State::SendHeader { io };
133                            return Poll::Pending;
134                        }
135                    }
136
137                    let h = HeaderLine::from(*this.version);
138                    if let Err(err) = Pin::new(&mut io).start_send(Message::Header(h)) {
139                        return Poll::Ready(Err(From::from(err)));
140                    }
141
142                    let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
143
144                    // The dialer always sends the header and the first protocol
145                    // proposal in one go for efficiency.
146                    *this.state = State::SendProtocol {
147                        io,
148                        protocol,
149                        header_received: false,
150                    };
151                }
152
153                State::SendProtocol {
154                    mut io,
155                    protocol,
156                    header_received,
157                } => {
158                    match Pin::new(&mut io).poll_ready(cx)? {
159                        Poll::Ready(()) => {}
160                        Poll::Pending => {
161                            *this.state = State::SendProtocol {
162                                io,
163                                protocol,
164                                header_received,
165                            };
166                            return Poll::Pending;
167                        }
168                    }
169
170                    let p = Protocol::try_from(protocol.as_ref())?;
171                    if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
172                        return Poll::Ready(Err(From::from(err)));
173                    }
174                    tracing::debug!(target: LOG_TARGET, "Dialer: Proposed protocol: {}", p);
175
176                    if this.protocols.peek().is_some() {
177                        *this.state = State::FlushProtocol {
178                            io,
179                            protocol,
180                            header_received,
181                        }
182                    } else {
183                        match this.version {
184                            Version::V1 =>
185                                *this.state = State::FlushProtocol {
186                                    io,
187                                    protocol,
188                                    header_received,
189                                },
190                            // This is the only effect that `V1Lazy` has compared to `V1`:
191                            // Optimistically settling on the only protocol that
192                            // the dialer supports for this negotiation. Notably,
193                            // the dialer expects a regular `V1` response.
194                            Version::V1Lazy => {
195                                tracing::debug!(
196                                    target: LOG_TARGET,
197                                    "Dialer: Expecting proposed protocol: {}",
198                                    p
199                                );
200                                let hl = HeaderLine::from(Version::V1Lazy);
201                                let io = Negotiated::expecting(io.into_reader(), p, Some(hl));
202                                return Poll::Ready(Ok((protocol, io)));
203                            }
204                        }
205                    }
206                }
207
208                State::FlushProtocol {
209                    mut io,
210                    protocol,
211                    header_received,
212                } => match Pin::new(&mut io).poll_flush(cx)? {
213                    Poll::Ready(()) =>
214                        *this.state = State::AwaitProtocol {
215                            io,
216                            protocol,
217                            header_received,
218                        },
219                    Poll::Pending => {
220                        *this.state = State::FlushProtocol {
221                            io,
222                            protocol,
223                            header_received,
224                        };
225                        return Poll::Pending;
226                    }
227                },
228
229                State::AwaitProtocol {
230                    mut io,
231                    protocol,
232                    header_received,
233                } => {
234                    let msg = match Pin::new(&mut io).poll_next(cx)? {
235                        Poll::Ready(Some(msg)) => msg,
236                        Poll::Pending => {
237                            *this.state = State::AwaitProtocol {
238                                io,
239                                protocol,
240                                header_received,
241                            };
242                            return Poll::Pending;
243                        }
244                        // Treat EOF error as [`NegotiationError::Failed`], not as
245                        // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
246                        // stream as a permissible way to "gracefully" fail a negotiation.
247                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
248                    };
249
250                    match msg {
251                        Message::Header(v)
252                            if v == HeaderLine::from(*this.version) && !header_received =>
253                        {
254                            *this.state = State::AwaitProtocol {
255                                io,
256                                protocol,
257                                header_received: true,
258                            };
259                        }
260                        Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => {
261                            tracing::debug!(
262                                target: LOG_TARGET,
263                                "Dialer: Received confirmation for protocol: {}",
264                                p
265                            );
266                            let io = Negotiated::completed(io.into_inner());
267                            return Poll::Ready(Ok((protocol, io)));
268                        }
269                        Message::NotAvailable => {
270                            tracing::debug!(
271                                target: LOG_TARGET,
272                                "Dialer: Received rejection of protocol: {}",
273                                String::from_utf8_lossy(protocol.as_ref())
274                            );
275                            let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
276                            *this.state = State::SendProtocol {
277                                io,
278                                protocol,
279                                header_received,
280                            }
281                        }
282                        _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
283                    }
284                }
285
286                State::Done => panic!("State::poll called after completion"),
287            }
288        }
289    }
290}
291
292/// `multistream-select` handshake result for dialer.
293#[derive(Debug, PartialEq, Eq)]
294pub enum HandshakeResult {
295    /// Handshake is not complete, data missing.
296    NotReady,
297
298    /// Handshake has succeeded.
299    ///
300    /// The returned tuple contains the negotiated protocol and response
301    /// that must be sent to remote peer.
302    Succeeded(ProtocolName),
303}
304
/// Progress of the dialer handshake.
#[derive(Debug)]
enum HandshakeState {
    /// Waiting to receive any response from the remote peer.
    WaitingResponse,

    /// Waiting to receive the actual application protocol from the remote peer.
    WaitingProtocol,
}
314
315/// `multistream-select` dialer handshake state.
316#[derive(Debug)]
317pub struct DialerState {
318    /// Proposed main protocol.
319    protocol: ProtocolName,
320
321    /// Fallback names of the main protocol.
322    fallback_names: Vec<ProtocolName>,
323
324    /// Dialer handshake state.
325    state: HandshakeState,
326}
327
328impl DialerState {
329    /// Propose protocol to remote peer.
330    ///
331    /// Return [`DialerState`] which is used to drive forward the negotiation and an encoded
332    /// `multistream-select` message that contains the protocol proposal for the substream.
333    pub fn propose(
334        protocol: ProtocolName,
335        fallback_names: Vec<ProtocolName>,
336    ) -> crate::Result<(Self, Vec<u8>)> {
337        let message = encode_multistream_message(
338            std::iter::once(protocol.clone())
339                .chain(fallback_names.clone())
340                .filter_map(|protocol| Protocol::try_from(protocol.as_ref()).ok())
341                .map(Message::Protocol),
342        )?
343        .freeze()
344        .to_vec();
345
346        Ok((
347            Self {
348                protocol,
349                fallback_names,
350                state: HandshakeState::WaitingResponse,
351            },
352            message,
353        ))
354    }
355
356    /// Register response to [`DialerState`].
357    pub fn register_response(
358        &mut self,
359        payload: Vec<u8>,
360    ) -> Result<HandshakeResult, crate::error::NegotiationError> {
361        let Message::Protocols(protocols) =
362            Message::decode(payload.into()).map_err(|_| ParseError::InvalidData)?
363        else {
364            return Err(crate::error::NegotiationError::MultistreamSelectError(
365                NegotiationError::Failed,
366            ));
367        };
368
369        let mut protocol_iter = protocols.into_iter();
370        loop {
371            match (&self.state, protocol_iter.next()) {
372                (HandshakeState::WaitingResponse, None) =>
373                    return Err(crate::error::NegotiationError::StateMismatch),
374                (HandshakeState::WaitingResponse, Some(protocol)) => {
375                    let header = Protocol::try_from(&b"/multistream/1.0.0"[..])
376                        .expect("valid multitstream-select header");
377
378                    if protocol == header {
379                        self.state = HandshakeState::WaitingProtocol;
380                    } else {
381                        return Err(crate::error::NegotiationError::MultistreamSelectError(
382                            NegotiationError::Failed,
383                        ));
384                    }
385                }
386                (HandshakeState::WaitingProtocol, Some(protocol)) => {
387                    if self.protocol.as_bytes() == protocol.as_ref() {
388                        return Ok(HandshakeResult::Succeeded(self.protocol.clone()));
389                    }
390
391                    for fallback in &self.fallback_names {
392                        if fallback.as_bytes() == protocol.as_ref() {
393                            return Ok(HandshakeResult::Succeeded(fallback.clone()));
394                        }
395                    }
396
397                    return Err(crate::error::NegotiationError::MultistreamSelectError(
398                        NegotiationError::Failed,
399                    ));
400                }
401                (HandshakeState::WaitingProtocol, None) => {
402                    return Ok(HandshakeResult::NotReady);
403                }
404            }
405        }
406    }
407}
408
409#[cfg(test)]
410mod tests {
411    use super::*;
412    use crate::multistream_select::listener_select_proto;
413    use std::time::Duration;
414    use tokio::net::{TcpListener, TcpStream};
415
416    #[tokio::test]
417    async fn select_proto_basic() {
418        async fn run(version: Version) {
419            let (client_connection, server_connection) = futures_ringbuf::Endpoint::pair(100, 100);
420
421            let server: tokio::task::JoinHandle<Result<(), ()>> = tokio::spawn(async move {
422                let protos = vec!["/proto1", "/proto2"];
423                let (proto, mut io) =
424                    listener_select_proto(server_connection, protos).await.unwrap();
425                assert_eq!(proto, "/proto2");
426
427                let mut out = vec![0; 32];
428                let n = io.read(&mut out).await.unwrap();
429                out.truncate(n);
430                assert_eq!(out, b"ping");
431
432                io.write_all(b"pong").await.unwrap();
433                io.flush().await.unwrap();
434
435                Ok(())
436            });
437
438            let client: tokio::task::JoinHandle<Result<(), ()>> = tokio::spawn(async move {
439                let protos = vec!["/proto3", "/proto2"];
440                let (proto, mut io) =
441                    dialer_select_proto(client_connection, protos, version).await.unwrap();
442                assert_eq!(proto, "/proto2");
443
444                io.write_all(b"ping").await.unwrap();
445                io.flush().await.unwrap();
446
447                let mut out = vec![0; 32];
448                let n = io.read(&mut out).await.unwrap();
449                out.truncate(n);
450                assert_eq!(out, b"pong");
451
452                Ok(())
453            });
454
455            server.await.unwrap();
456            client.await.unwrap();
457        }
458
459        run(Version::V1).await;
460        run(Version::V1Lazy).await;
461    }
462
463    /// Tests the expected behaviour of failed negotiations.
464    #[tokio::test]
465    async fn negotiation_failed() {
466        async fn run(
467            version: Version,
468            dial_protos: Vec<&'static str>,
469            dial_payload: Vec<u8>,
470            listen_protos: Vec<&'static str>,
471        ) {
472            let (client_connection, server_connection) = futures_ringbuf::Endpoint::pair(100, 100);
473
474            let server: tokio::task::JoinHandle<Result<(), ()>> = tokio::spawn(async move {
475                let io = match tokio::time::timeout(
476                    Duration::from_secs(2),
477                    listener_select_proto(server_connection, listen_protos),
478                )
479                .await
480                .unwrap()
481                {
482                    Ok((_, io)) => io,
483                    Err(NegotiationError::Failed) => return Ok(()),
484                    Err(NegotiationError::ProtocolError(e)) => {
485                        panic!("Unexpected protocol error {e}")
486                    }
487                };
488                match io.complete().await {
489                    Err(NegotiationError::Failed) => {}
490                    _ => panic!(),
491                }
492
493                Ok(())
494            });
495
496            let client: tokio::task::JoinHandle<Result<(), ()>> = tokio::spawn(async move {
497                let mut io = match tokio::time::timeout(
498                    Duration::from_secs(2),
499                    dialer_select_proto(client_connection, dial_protos, version),
500                )
501                .await
502                .unwrap()
503                {
504                    Err(NegotiationError::Failed) => return Ok(()),
505                    Ok((_, io)) => io,
506                    Err(_) => panic!(),
507                };
508
509                // The dialer may write a payload that is even sent before it
510                // got confirmation of the last proposed protocol, when `V1Lazy`
511                // is used.
512                io.write_all(&dial_payload).await.unwrap();
513                match io.complete().await {
514                    Err(NegotiationError::Failed) => {}
515                    _ => panic!(),
516                }
517
518                Ok(())
519            });
520
521            server.await.unwrap();
522            client.await.unwrap();
523        }
524
525        // Incompatible protocols.
526        run(Version::V1, vec!["/proto1"], vec![1], vec!["/proto2"]).await;
527        run(Version::V1Lazy, vec!["/proto1"], vec![1], vec!["/proto2"]).await;
528    }
529
530    #[tokio::test]
531    async fn v1_lazy_do_not_wait_for_negotiation_on_poll_close() {
532        let (client_connection, _server_connection) =
533            futures_ringbuf::Endpoint::pair(1024 * 1024, 1);
534
535        let client = tokio::spawn(async move {
536            // Single protocol to allow for lazy (or optimistic) protocol negotiation.
537            let protos = vec!["/proto1"];
538            let (proto, mut io) =
539                dialer_select_proto(client_connection, protos, Version::V1Lazy).await.unwrap();
540            assert_eq!(proto, "/proto1");
541
542            // In Libp2p the lazy negotation of protocols can be closed at any time,
543            // even if the negotiation is not yet done.
544
545            // However, for the Litep2p the negotation must conclude before closing the
546            // lazy negotation of protocol. We'll wait for the close until the
547            // server has produced a message, in this test that means forever.
548            io.close().await.unwrap();
549        });
550
551        // TODO: Once https://github.com/paritytech/litep2p/pull/62 is merged, this
552        // should be changed to `is_ok`.
553        assert!(tokio::time::timeout(Duration::from_secs(10), client).await.is_err());
554    }
555
556    #[tokio::test]
557    async fn low_level_negotiate() {
558        async fn run(version: Version) {
559            let (client_connection, mut server_connection) =
560                futures_ringbuf::Endpoint::pair(100, 100);
561
562            let server = tokio::spawn(async move {
563                let protos = vec!["/proto2"];
564
565                let multistream = b"/multistream/1.0.0\n";
566                let len = multistream.len();
567                let proto = b"/proto2\n";
568                let proto_len = proto.len();
569
570                // Check that our implementation writes optimally
571                // the multistream ++ protocol in a single message.
572                let mut expected_message = Vec::new();
573                expected_message.push(len as u8);
574                expected_message.extend_from_slice(multistream);
575                expected_message.push(proto_len as u8);
576                expected_message.extend_from_slice(proto);
577
578                if version == Version::V1Lazy {
579                    expected_message.extend_from_slice(b"ping");
580                }
581
582                let mut out = vec![0; 64];
583                let n = server_connection.read(&mut out).await.unwrap();
584                out.truncate(n);
585                assert_eq!(out, expected_message);
586
587                // We must send the back the multistream packet.
588                let mut send_message = Vec::new();
589                send_message.push(len as u8);
590                send_message.extend_from_slice(multistream);
591
592                server_connection.write_all(&mut send_message).await.unwrap();
593
594                let mut send_message = Vec::new();
595                send_message.push(proto_len as u8);
596                send_message.extend_from_slice(proto);
597                server_connection.write_all(&mut send_message).await.unwrap();
598
599                // Handle handshake.
600                match version {
601                    Version::V1 => {
602                        let mut out = vec![0; 64];
603                        let n = server_connection.read(&mut out).await.unwrap();
604                        out.truncate(n);
605                        assert_eq!(out, b"ping");
606
607                        server_connection.write_all(b"pong").await.unwrap();
608                    }
609                    Version::V1Lazy => {
610                        // Ping (handshake) payload expected in the initial message.
611                        server_connection.write_all(b"pong").await.unwrap();
612                    }
613                }
614            });
615
616            let client = tokio::spawn(async move {
617                let protos = vec!["/proto2"];
618                let (proto, mut io) =
619                    dialer_select_proto(client_connection, protos, version).await.unwrap();
620                assert_eq!(proto, "/proto2");
621
622                io.write_all(b"ping").await.unwrap();
623                io.flush().await.unwrap();
624
625                let mut out = vec![0; 32];
626                let n = io.read(&mut out).await.unwrap();
627                out.truncate(n);
628                assert_eq!(out, b"pong");
629            });
630
631            server.await.unwrap();
632            client.await.unwrap();
633        }
634
635        run(Version::V1).await;
636        run(Version::V1Lazy).await;
637    }
638
639    #[tokio::test]
640    async fn v1_low_level_negotiate_multiple_headers() {
641        let (client_connection, mut server_connection) = futures_ringbuf::Endpoint::pair(100, 100);
642
643        let server = tokio::spawn(async move {
644            let protos = vec!["/proto2"];
645
646            let multistream = b"/multistream/1.0.0\n";
647            let len = multistream.len();
648            let proto = b"/proto2\n";
649            let proto_len = proto.len();
650
651            // Check that our implementation writes optimally
652            // the multistream ++ protocol in a single message.
653            let mut expected_message = Vec::new();
654            expected_message.push(len as u8);
655            expected_message.extend_from_slice(multistream);
656            expected_message.push(proto_len as u8);
657            expected_message.extend_from_slice(proto);
658
659            let mut out = vec![0; 64];
660            let n = server_connection.read(&mut out).await.unwrap();
661            out.truncate(n);
662            assert_eq!(out, expected_message);
663
664            // We must send the back the multistream packet.
665            let mut send_message = Vec::new();
666            send_message.push(len as u8);
667            send_message.extend_from_slice(multistream);
668
669            server_connection.write_all(&mut send_message).await.unwrap();
670
671            // We must send the back the multistream packet again.
672            let mut send_message = Vec::new();
673            send_message.push(len as u8);
674            send_message.extend_from_slice(multistream);
675
676            server_connection.write_all(&mut send_message).await.unwrap();
677        });
678
679        let client = tokio::spawn(async move {
680            let protos = vec!["/proto2"];
681
682            // Negotiation fails because the protocol receives the `/multistream/1.0.0` header
683            // multiple times.
684            let result =
685                dialer_select_proto(client_connection, protos, Version::V1).await.unwrap_err();
686            match result {
687                NegotiationError::ProtocolError(ProtocolError::InvalidMessage) => {}
688                _ => panic!("unexpected error: {:?}", result),
689            };
690        });
691
692        server.await.unwrap();
693        client.await.unwrap();
694    }
695
696    #[tokio::test]
697    async fn v1_lazy_low_level_negotiate_multiple_headers() {
698        let (client_connection, mut server_connection) = futures_ringbuf::Endpoint::pair(100, 100);
699
700        let server = tokio::spawn(async move {
701            let protos = vec!["/proto2"];
702
703            let multistream = b"/multistream/1.0.0\n";
704            let len = multistream.len();
705            let proto = b"/proto2\n";
706            let proto_len = proto.len();
707
708            // Check that our implementation writes optimally
709            // the multistream ++ protocol in a single message.
710            let mut expected_message = Vec::new();
711            expected_message.push(len as u8);
712            expected_message.extend_from_slice(multistream);
713            expected_message.push(proto_len as u8);
714            expected_message.extend_from_slice(proto);
715
716            let mut out = vec![0; 64];
717            let n = server_connection.read(&mut out).await.unwrap();
718            out.truncate(n);
719            assert_eq!(out, expected_message);
720
721            // We must send the back the multistream packet.
722            let mut send_message = Vec::new();
723            send_message.push(len as u8);
724            send_message.extend_from_slice(multistream);
725
726            server_connection.write_all(&mut send_message).await.unwrap();
727
728            // We must send the back the multistream packet again.
729            let mut send_message = Vec::new();
730            send_message.push(len as u8);
731            send_message.extend_from_slice(multistream);
732
733            server_connection.write_all(&mut send_message).await.unwrap();
734        });
735
736        let client = tokio::spawn(async move {
737            let protos = vec!["/proto2"];
738
739            // Negotiation fails because the protocol receives the `/multistream/1.0.0` header
740            // multiple times.
741            let (proto, to_negociate) =
742                dialer_select_proto(client_connection, protos, Version::V1Lazy).await.unwrap();
743            assert_eq!(proto, "/proto2");
744
745            let result = to_negociate.complete().await.unwrap_err();
746
747            match result {
748                NegotiationError::ProtocolError(ProtocolError::InvalidMessage) => {}
749                _ => panic!("unexpected error: {:?}", result),
750            };
751        });
752
753        server.await.unwrap();
754        client.await.unwrap();
755    }
756
757    #[test]
758    fn propose() {
759        let (mut dialer_state, message) =
760            DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
761        let message = bytes::BytesMut::from(&message[..]).freeze();
762
763        let Message::Protocols(protocols) = Message::decode(message).unwrap() else {
764            panic!("invalid message type");
765        };
766
767        assert_eq!(protocols.len(), 2);
768        assert_eq!(
769            protocols[0],
770            Protocol::try_from(&b"/multistream/1.0.0"[..])
771                .expect("valid multitstream-select header")
772        );
773        assert_eq!(
774            protocols[1],
775            Protocol::try_from(&b"/13371338/proto/1"[..])
776                .expect("valid multitstream-select header")
777        );
778    }
779
780    #[test]
781    fn propose_with_fallback() {
782        let (mut dialer_state, message) = DialerState::propose(
783            ProtocolName::from("/13371338/proto/1"),
784            vec![ProtocolName::from("/sup/proto/1")],
785        )
786        .unwrap();
787        let message = bytes::BytesMut::from(&message[..]).freeze();
788
789        let Message::Protocols(protocols) = Message::decode(message).unwrap() else {
790            panic!("invalid message type");
791        };
792
793        assert_eq!(protocols.len(), 3);
794        assert_eq!(
795            protocols[0],
796            Protocol::try_from(&b"/multistream/1.0.0"[..])
797                .expect("valid multitstream-select header")
798        );
799        assert_eq!(
800            protocols[1],
801            Protocol::try_from(&b"/13371338/proto/1"[..])
802                .expect("valid multitstream-select header")
803        );
804        assert_eq!(
805            protocols[2],
806            Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid multitstream-select header")
807        );
808    }
809
810    #[test]
811    fn register_response_invalid_message() {
812        // send only header line
813        let mut bytes = BytesMut::with_capacity(32);
814        let message = Message::Header(HeaderLine::V1);
815        let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();
816
817        let (mut dialer_state, _message) =
818            DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
819
820        match dialer_state.register_response(bytes.freeze().to_vec()) {
821            Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {}
822            event => panic!("invalid event: {event:?}"),
823        }
824    }
825
826    #[test]
827    fn header_line_missing() {
828        // header line missing
829        let mut bytes = BytesMut::with_capacity(256);
830        let message = Message::Protocols(vec![
831            Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
832            Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
833        ]);
834        let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();
835
836        let (mut dialer_state, _message) =
837            DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
838
839        match dialer_state.register_response(bytes.freeze().to_vec()) {
840            Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {}
841            event => panic!("invalid event: {event:?}"),
842        }
843    }
844
845    #[test]
846    fn negotiate_main_protocol() {
847        let message = encode_multistream_message(
848            vec![Message::Protocol(
849                Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
850            )]
851            .into_iter(),
852        )
853        .unwrap()
854        .freeze();
855
856        let (mut dialer_state, _message) = DialerState::propose(
857            ProtocolName::from("/13371338/proto/1"),
858            vec![ProtocolName::from("/sup/proto/1")],
859        )
860        .unwrap();
861
862        match dialer_state.register_response(message.to_vec()) {
863            Ok(HandshakeResult::Succeeded(negotiated)) =>
864                assert_eq!(negotiated, ProtocolName::from("/13371338/proto/1")),
865            _ => panic!("invalid event"),
866        }
867    }
868
869    #[test]
870    fn negotiate_fallback_protocol() {
871        let message = encode_multistream_message(
872            vec![Message::Protocol(
873                Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
874            )]
875            .into_iter(),
876        )
877        .unwrap()
878        .freeze();
879
880        let (mut dialer_state, _message) = DialerState::propose(
881            ProtocolName::from("/13371338/proto/1"),
882            vec![ProtocolName::from("/sup/proto/1")],
883        )
884        .unwrap();
885
886        match dialer_state.register_response(message.to_vec()) {
887            Ok(HandshakeResult::Succeeded(negotiated)) =>
888                assert_eq!(negotiated, ProtocolName::from("/sup/proto/1")),
889            _ => panic!("invalid event"),
890        }
891    }
892}