1use crate::{
24 codec::unsigned_varint::UnsignedVarint,
25 error::{self, Error, ParseError},
26 multistream_select::{
27 protocol::{
28 encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError,
29 },
30 Negotiated, NegotiationError, Version,
31 },
32 types::protocol::ProtocolName,
33};
34
35use bytes::BytesMut;
36use futures::prelude::*;
37use rustls::internal::msgs::hsjoiner::HandshakeJoiner;
38use std::{
39 convert::TryFrom as _,
40 iter, mem,
41 pin::Pin,
42 task::{Context, Poll},
43};
44
/// `tracing` target used for every log line emitted by this module.
const LOG_TARGET: &str = "litep2p::multistream-select";
46
47pub fn dialer_select_proto<R, I>(
61 inner: R,
62 protocols: I,
63 version: Version,
64) -> DialerSelectFuture<R, I::IntoIter>
65where
66 R: AsyncRead + AsyncWrite,
67 I: IntoIterator,
68 I::Item: AsRef<[u8]>,
69{
70 let protocols = protocols.into_iter().peekable();
71 DialerSelectFuture {
72 version,
73 protocols,
74 state: State::SendHeader {
75 io: MessageIO::new(inner),
76 },
77 }
78}
79
/// Future returned by [`dialer_select_proto`] that performs the dialer side of a
/// multistream-select negotiation.
#[pin_project::pin_project]
pub struct DialerSelectFuture<R, I: Iterator> {
    /// Protocols not yet proposed; `peek` is used to detect the last candidate.
    protocols: iter::Peekable<I>,
    /// Current step of the negotiation state machine.
    state: State<R, I::Item>,
    /// Multistream-select version used for the header line and to decide whether the
    /// last candidate is negotiated lazily.
    version: Version,
}
89
/// States of the [`DialerSelectFuture`] negotiation state machine.
///
/// `header_received` records whether the remote's header line has already been accepted.
/// A second header is therefore treated as an invalid message.
enum State<R, N> {
    /// The header line still has to be queued on `io`.
    SendHeader {
        io: MessageIO<R>,
    },
    /// `protocol` still has to be queued on `io` as the next proposal.
    SendProtocol {
        io: MessageIO<R>,
        protocol: N,
        header_received: bool,
    },
    /// The proposal for `protocol` has been queued and `io` must be flushed.
    FlushProtocol {
        io: MessageIO<R>,
        protocol: N,
        header_received: bool,
    },
    /// Waiting for the remote to confirm or reject `protocol`.
    AwaitProtocol {
        io: MessageIO<R>,
        protocol: N,
        header_received: bool,
    },
    /// Placeholder left in place while a state is being processed. Also the state left
    /// behind once the future has completed; polling it again panics.
    Done,
}
111
/// Drives the dialer side of the negotiation.
///
/// On every loop iteration the current state is moved out of `this.state`, leaving
/// `State::Done` behind. Each branch then does one of three things:
/// - stores the next state and loops again;
/// - stores the same state back and returns `Poll::Pending`;
/// - returns a final result.
impl<R, I> Future for DialerSelectFuture<R, I>
where
    R: AsyncRead + AsyncWrite + Unpin,
    I: Iterator,
    I::Item: AsRef<[u8]>,
{
    type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        loop {
            match mem::replace(this.state, State::Done) {
                // Queue the header line for the configured version, then take the first
                // candidate. An empty candidate list fails the negotiation here.
                State::SendHeader { mut io } => {
                    match Pin::new(&mut io).poll_ready(cx)? {
                        Poll::Ready(()) => {}
                        Poll::Pending => {
                            *this.state = State::SendHeader { io };
                            return Poll::Pending;
                        }
                    }

                    let h = HeaderLine::from(*this.version);
                    if let Err(err) = Pin::new(&mut io).start_send(Message::Header(h)) {
                        return Poll::Ready(Err(From::from(err)));
                    }

                    let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;

                    *this.state = State::SendProtocol {
                        io,
                        protocol,
                        header_received: false,
                    };
                }

                // Queue the proposal for the current candidate.
                State::SendProtocol {
                    mut io,
                    protocol,
                    header_received,
                } => {
                    match Pin::new(&mut io).poll_ready(cx)? {
                        Poll::Ready(()) => {}
                        Poll::Pending => {
                            *this.state = State::SendProtocol {
                                io,
                                protocol,
                                header_received,
                            };
                            return Poll::Pending;
                        }
                    }

                    // A candidate that is not a valid protocol name aborts the negotiation.
                    let p = Protocol::try_from(protocol.as_ref())?;
                    if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
                        return Poll::Ready(Err(From::from(err)));
                    }
                    tracing::debug!(target: LOG_TARGET, "Dialer: Proposed protocol: {}", p);

                    if this.protocols.peek().is_some() {
                        // More candidates remain: flush and wait for the remote's answer.
                        *this.state = State::FlushProtocol {
                            io,
                            protocol,
                            header_received,
                        }
                    } else {
                        match this.version {
                            // Last candidate with V1: still wait for an explicit answer.
                            Version::V1 =>
                                *this.state = State::FlushProtocol {
                                    io,
                                    protocol,
                                    header_received,
                                },
                            // Last candidate with V1Lazy: return immediately, without
                            // flushing here and without waiting for confirmation. The
                            // header and protocol still owed by the remote are handed to
                            // `Negotiated::expecting`.
                            // NOTE(review): presumably checked on a later read or on
                            // `complete()` — confirm in `Negotiated`.
                            Version::V1Lazy => {
                                tracing::debug!(
                                    target: LOG_TARGET,
                                    "Dialer: Expecting proposed protocol: {}",
                                    p
                                );
                                let hl = HeaderLine::from(Version::V1Lazy);
                                let io = Negotiated::expecting(io.into_reader(), p, Some(hl));
                                return Poll::Ready(Ok((protocol, io)));
                            }
                        }
                    }
                }

                // Flush the queued messages before reading the remote's answer.
                State::FlushProtocol {
                    mut io,
                    protocol,
                    header_received,
                } => match Pin::new(&mut io).poll_flush(cx)? {
                    Poll::Ready(()) =>
                        *this.state = State::AwaitProtocol {
                            io,
                            protocol,
                            header_received,
                        },
                    Poll::Pending => {
                        *this.state = State::FlushProtocol {
                            io,
                            protocol,
                            header_received,
                        };
                        return Poll::Pending;
                    }
                },

                // Read and interpret one message from the remote.
                State::AwaitProtocol {
                    mut io,
                    protocol,
                    header_received,
                } => {
                    let msg = match Pin::new(&mut io).poll_next(cx)? {
                        Poll::Ready(Some(msg)) => msg,
                        Poll::Pending => {
                            *this.state = State::AwaitProtocol {
                                io,
                                protocol,
                                header_received,
                            };
                            return Poll::Pending;
                        }
                        // The remote closed the stream before answering.
                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
                    };

                    match msg {
                        // The matching header is accepted exactly once; a repeated header
                        // falls through to the invalid-message arm below.
                        Message::Header(v)
                            if v == HeaderLine::from(*this.version) && !header_received =>
                        {
                            *this.state = State::AwaitProtocol {
                                io,
                                protocol,
                                header_received: true,
                            };
                        }
                        // The remote echoed the current proposal: negotiation is complete.
                        Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => {
                            tracing::debug!(
                                target: LOG_TARGET,
                                "Dialer: Received confirmation for protocol: {}",
                                p
                            );
                            let io = Negotiated::completed(io.into_inner());
                            return Poll::Ready(Ok((protocol, io)));
                        }
                        // Current proposal rejected: propose the next candidate, or fail
                        // if none remain.
                        Message::NotAvailable => {
                            tracing::debug!(
                                target: LOG_TARGET,
                                "Dialer: Received rejection of protocol: {}",
                                String::from_utf8_lossy(protocol.as_ref())
                            );
                            let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
                            *this.state = State::SendProtocol {
                                io,
                                protocol,
                                header_received,
                            }
                        }
                        // Everything else is rejected as invalid. This covers a duplicate or
                        // wrong-version header and a confirmation for a different protocol.
                        _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
                    }
                }

                State::Done => panic!("State::poll called after completion"),
            }
        }
    }
}
291
/// Outcome of feeding a response into [`DialerState::register_response`].
#[derive(Debug, PartialEq, Eq)]
pub enum HandshakeResult {
    /// The header was accepted, but the response did not yet contain a protocol confirmation.
    NotReady,

    /// The remote selected this protocol: either the main proposal or one of the fallbacks.
    Succeeded(ProtocolName),
}
304
/// Progress of a [`DialerState`] negotiation.
#[derive(Debug)]
enum HandshakeState {
    /// Waiting for the `/multistream/1.0.0` header line from the remote.
    WaitingResponse,

    /// Header received; waiting for the remote to name one of the proposed protocols.
    WaitingProtocol,
}
314
/// Dialer-side multistream-select state for callers that exchange whole message payloads
/// themselves.
///
/// Created by [`DialerState::propose`] and advanced by [`DialerState::register_response`].
#[derive(Debug)]
pub struct DialerState {
    /// Main protocol proposed to the remote.
    protocol: ProtocolName,

    /// Fallback protocols proposed after the main protocol, in order.
    fallback_names: Vec<ProtocolName>,

    /// Current negotiation progress.
    state: HandshakeState,
}
327
328impl DialerState {
329 pub fn propose(
334 protocol: ProtocolName,
335 fallback_names: Vec<ProtocolName>,
336 ) -> crate::Result<(Self, Vec<u8>)> {
337 let message = encode_multistream_message(
338 std::iter::once(protocol.clone())
339 .chain(fallback_names.clone())
340 .filter_map(|protocol| Protocol::try_from(protocol.as_ref()).ok())
341 .map(Message::Protocol),
342 )?
343 .freeze()
344 .to_vec();
345
346 Ok((
347 Self {
348 protocol,
349 fallback_names,
350 state: HandshakeState::WaitingResponse,
351 },
352 message,
353 ))
354 }
355
356 pub fn register_response(
358 &mut self,
359 payload: Vec<u8>,
360 ) -> Result<HandshakeResult, crate::error::NegotiationError> {
361 let Message::Protocols(protocols) =
362 Message::decode(payload.into()).map_err(|_| ParseError::InvalidData)?
363 else {
364 return Err(crate::error::NegotiationError::MultistreamSelectError(
365 NegotiationError::Failed,
366 ));
367 };
368
369 let mut protocol_iter = protocols.into_iter();
370 loop {
371 match (&self.state, protocol_iter.next()) {
372 (HandshakeState::WaitingResponse, None) =>
373 return Err(crate::error::NegotiationError::StateMismatch),
374 (HandshakeState::WaitingResponse, Some(protocol)) => {
375 let header = Protocol::try_from(&b"/multistream/1.0.0"[..])
376 .expect("valid multitstream-select header");
377
378 if protocol == header {
379 self.state = HandshakeState::WaitingProtocol;
380 } else {
381 return Err(crate::error::NegotiationError::MultistreamSelectError(
382 NegotiationError::Failed,
383 ));
384 }
385 }
386 (HandshakeState::WaitingProtocol, Some(protocol)) => {
387 if self.protocol.as_bytes() == protocol.as_ref() {
388 return Ok(HandshakeResult::Succeeded(self.protocol.clone()));
389 }
390
391 for fallback in &self.fallback_names {
392 if fallback.as_bytes() == protocol.as_ref() {
393 return Ok(HandshakeResult::Succeeded(fallback.clone()));
394 }
395 }
396
397 return Err(crate::error::NegotiationError::MultistreamSelectError(
398 NegotiationError::Failed,
399 ));
400 }
401 (HandshakeState::WaitingProtocol, None) => {
402 return Ok(HandshakeResult::NotReady);
403 }
404 }
405 }
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multistream_select::listener_select_proto;
    use std::time::Duration;

    /// End-to-end negotiation over an in-memory pipe followed by a ping/pong exchange.
    /// The dialer's first proposal is unsupported by the listener, so `/proto2` must win.
    #[tokio::test]
    async fn select_proto_basic() {
        async fn run(version: Version) {
            let (client_connection, server_connection) = futures_ringbuf::Endpoint::pair(100, 100);

            let server: tokio::task::JoinHandle<Result<(), ()>> = tokio::spawn(async move {
                let protos = vec!["/proto1", "/proto2"];
                let (proto, mut io) =
                    listener_select_proto(server_connection, protos).await.unwrap();
                assert_eq!(proto, "/proto2");

                let mut out = vec![0; 32];
                let n = io.read(&mut out).await.unwrap();
                out.truncate(n);
                assert_eq!(out, b"ping");

                io.write_all(b"pong").await.unwrap();
                io.flush().await.unwrap();

                Ok(())
            });

            let client: tokio::task::JoinHandle<Result<(), ()>> = tokio::spawn(async move {
                let protos = vec!["/proto3", "/proto2"];
                let (proto, mut io) =
                    dialer_select_proto(client_connection, protos, version).await.unwrap();
                assert_eq!(proto, "/proto2");

                io.write_all(b"ping").await.unwrap();
                io.flush().await.unwrap();

                let mut out = vec![0; 32];
                let n = io.read(&mut out).await.unwrap();
                out.truncate(n);
                assert_eq!(out, b"pong");

                Ok(())
            });

            // Both the join result and the task's own result must be successful.
            server.await.unwrap().unwrap();
            client.await.unwrap().unwrap();
        }

        run(Version::V1).await;
        run(Version::V1Lazy).await;
    }

    /// With no protocol in common, negotiation must fail on both sides. With `V1Lazy`,
    /// the failure may only surface when the optimistic stream is completed.
    #[tokio::test]
    async fn negotiation_failed() {
        async fn run(
            version: Version,
            dial_protos: Vec<&'static str>,
            dial_payload: Vec<u8>,
            listen_protos: Vec<&'static str>,
        ) {
            let (client_connection, server_connection) = futures_ringbuf::Endpoint::pair(100, 100);

            let server: tokio::task::JoinHandle<Result<(), ()>> = tokio::spawn(async move {
                let io = match tokio::time::timeout(
                    Duration::from_secs(2),
                    listener_select_proto(server_connection, listen_protos),
                )
                .await
                .unwrap()
                {
                    Ok((_, io)) => io,
                    Err(NegotiationError::Failed) => return Ok(()),
                    Err(NegotiationError::ProtocolError(e)) => {
                        panic!("Unexpected protocol error {e}")
                    }
                };
                match io.complete().await {
                    Err(NegotiationError::Failed) => {}
                    _ => panic!(),
                }

                Ok(())
            });

            let client: tokio::task::JoinHandle<Result<(), ()>> = tokio::spawn(async move {
                let mut io = match tokio::time::timeout(
                    Duration::from_secs(2),
                    dialer_select_proto(client_connection, dial_protos, version),
                )
                .await
                .unwrap()
                {
                    Err(NegotiationError::Failed) => return Ok(()),
                    Ok((_, io)) => io,
                    Err(_) => panic!(),
                };

                // With `V1Lazy` the dialer may write data before the remote has answered
                // the last proposal.
                io.write_all(&dial_payload).await.unwrap();
                match io.complete().await {
                    Err(NegotiationError::Failed) => {}
                    _ => panic!(),
                }

                Ok(())
            });

            // Both the join result and the task's own result must be successful.
            server.await.unwrap().unwrap();
            client.await.unwrap().unwrap();
        }

        run(Version::V1, vec!["/proto1"], vec![1], vec!["/proto2"]).await;
        run(Version::V1Lazy, vec!["/proto1"], vec![1], vec!["/proto2"]).await;
    }

    /// Pins the current behaviour: closing a lazily negotiated stream does not finish while
    /// the remote never answers. The untouched server side keeps the client task pending
    /// until the timeout fires.
    #[tokio::test]
    async fn v1_lazy_do_not_wait_for_negotiation_on_poll_close() {
        let (client_connection, _server_connection) =
            futures_ringbuf::Endpoint::pair(1024 * 1024, 1);

        let client = tokio::spawn(async move {
            // A single protocol allows lazy (optimistic) negotiation.
            let protos = vec!["/proto1"];
            let (proto, mut io) =
                dialer_select_proto(client_connection, protos, Version::V1Lazy).await.unwrap();
            assert_eq!(proto, "/proto1");

            io.close().await.unwrap();
        });

        assert!(tokio::time::timeout(Duration::from_secs(10), client).await.is_err());
    }

    /// Checks the exact bytes the dialer puts on the wire against a hand-written listener.
    #[tokio::test]
    async fn low_level_negotiate() {
        async fn run(version: Version) {
            let (client_connection, mut server_connection) =
                futures_ringbuf::Endpoint::pair(100, 100);

            let server = tokio::spawn(async move {
                let multistream = b"/multistream/1.0.0\n";
                let len = multistream.len();
                let proto = b"/proto2\n";
                let proto_len = proto.len();

                // Each line is prefixed by its length (a single byte for these short lines).
                let mut expected_message = Vec::new();
                expected_message.push(len as u8);
                expected_message.extend_from_slice(multistream);
                expected_message.push(proto_len as u8);
                expected_message.extend_from_slice(proto);

                // With `V1Lazy` the client's payload follows the proposal without a round trip.
                if version == Version::V1Lazy {
                    expected_message.extend_from_slice(b"ping");
                }

                let mut out = vec![0; 64];
                let n = server_connection.read(&mut out).await.unwrap();
                out.truncate(n);
                assert_eq!(out, expected_message);

                let mut send_message = Vec::new();
                send_message.push(len as u8);
                send_message.extend_from_slice(multistream);

                server_connection.write_all(&send_message).await.unwrap();

                let mut send_message = Vec::new();
                send_message.push(proto_len as u8);
                send_message.extend_from_slice(proto);
                server_connection.write_all(&send_message).await.unwrap();

                match version {
                    // With `V1` the payload only arrives after the confirmation.
                    Version::V1 => {
                        let mut out = vec![0; 64];
                        let n = server_connection.read(&mut out).await.unwrap();
                        out.truncate(n);
                        assert_eq!(out, b"ping");

                        server_connection.write_all(b"pong").await.unwrap();
                    }
                    // With `V1Lazy` the payload was already read together with the proposal.
                    Version::V1Lazy => {
                        server_connection.write_all(b"pong").await.unwrap();
                    }
                }
            });

            let client = tokio::spawn(async move {
                let protos = vec!["/proto2"];
                let (proto, mut io) =
                    dialer_select_proto(client_connection, protos, version).await.unwrap();
                assert_eq!(proto, "/proto2");

                io.write_all(b"ping").await.unwrap();
                io.flush().await.unwrap();

                let mut out = vec![0; 32];
                let n = io.read(&mut out).await.unwrap();
                out.truncate(n);
                assert_eq!(out, b"pong");
            });

            server.await.unwrap();
            client.await.unwrap();
        }

        run(Version::V1).await;
        run(Version::V1Lazy).await;
    }

    /// A listener that sends the header line twice must make `V1` negotiation fail with
    /// `InvalidMessage`.
    #[tokio::test]
    async fn v1_low_level_negotiate_multiple_headers() {
        let (client_connection, mut server_connection) = futures_ringbuf::Endpoint::pair(100, 100);

        let server = tokio::spawn(async move {
            let multistream = b"/multistream/1.0.0\n";
            let len = multistream.len();
            let proto = b"/proto2\n";
            let proto_len = proto.len();

            // Length-prefixed header followed by the length-prefixed proposal.
            let mut expected_message = Vec::new();
            expected_message.push(len as u8);
            expected_message.extend_from_slice(multistream);
            expected_message.push(proto_len as u8);
            expected_message.extend_from_slice(proto);

            let mut out = vec![0; 64];
            let n = server_connection.read(&mut out).await.unwrap();
            out.truncate(n);
            assert_eq!(out, expected_message);

            // First header: valid.
            let mut send_message = Vec::new();
            send_message.push(len as u8);
            send_message.extend_from_slice(multistream);

            server_connection.write_all(&send_message).await.unwrap();

            // Second header: must be rejected by the dialer.
            let mut send_message = Vec::new();
            send_message.push(len as u8);
            send_message.extend_from_slice(multistream);

            server_connection.write_all(&send_message).await.unwrap();
        });

        let client = tokio::spawn(async move {
            let protos = vec!["/proto2"];

            let result =
                dialer_select_proto(client_connection, protos, Version::V1).await.unwrap_err();
            match result {
                NegotiationError::ProtocolError(ProtocolError::InvalidMessage) => {}
                _ => panic!("unexpected error: {:?}", result),
            };
        });

        server.await.unwrap();
        client.await.unwrap();
    }

    /// A duplicated header must also be rejected with `V1Lazy`. There the error surfaces when
    /// completing the optimistically returned stream.
    #[tokio::test]
    async fn v1_lazy_low_level_negotiate_multiple_headers() {
        let (client_connection, mut server_connection) = futures_ringbuf::Endpoint::pair(100, 100);

        let server = tokio::spawn(async move {
            let multistream = b"/multistream/1.0.0\n";
            let len = multistream.len();
            let proto = b"/proto2\n";
            let proto_len = proto.len();

            // Length-prefixed header followed by the length-prefixed proposal.
            let mut expected_message = Vec::new();
            expected_message.push(len as u8);
            expected_message.extend_from_slice(multistream);
            expected_message.push(proto_len as u8);
            expected_message.extend_from_slice(proto);

            let mut out = vec![0; 64];
            let n = server_connection.read(&mut out).await.unwrap();
            out.truncate(n);
            assert_eq!(out, expected_message);

            // First header: valid.
            let mut send_message = Vec::new();
            send_message.push(len as u8);
            send_message.extend_from_slice(multistream);

            server_connection.write_all(&send_message).await.unwrap();

            // Second header: must be rejected by the dialer.
            let mut send_message = Vec::new();
            send_message.push(len as u8);
            send_message.extend_from_slice(multistream);

            server_connection.write_all(&send_message).await.unwrap();
        });

        let client = tokio::spawn(async move {
            let protos = vec!["/proto2"];

            let (proto, negotiating) =
                dialer_select_proto(client_connection, protos, Version::V1Lazy).await.unwrap();
            assert_eq!(proto, "/proto2");

            let result = negotiating.complete().await.unwrap_err();

            match result {
                NegotiationError::ProtocolError(ProtocolError::InvalidMessage) => {}
                _ => panic!("unexpected error: {:?}", result),
            };
        });

        server.await.unwrap();
        client.await.unwrap();
    }

    /// `propose` without fallbacks encodes the header followed by the main protocol.
    #[test]
    fn propose() {
        let (_dialer_state, message) =
            DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
        let message = bytes::BytesMut::from(&message[..]).freeze();

        let Message::Protocols(protocols) = Message::decode(message).unwrap() else {
            panic!("invalid message type");
        };

        assert_eq!(protocols.len(), 2);
        assert_eq!(
            protocols[0],
            Protocol::try_from(&b"/multistream/1.0.0"[..])
                .expect("valid multitstream-select header")
        );
        assert_eq!(
            protocols[1],
            Protocol::try_from(&b"/13371338/proto/1"[..])
                .expect("valid multitstream-select header")
        );
    }

    /// Fallback protocols are encoded after the main protocol, in order.
    #[test]
    fn propose_with_fallback() {
        let (_dialer_state, message) = DialerState::propose(
            ProtocolName::from("/13371338/proto/1"),
            vec![ProtocolName::from("/sup/proto/1")],
        )
        .unwrap();
        let message = bytes::BytesMut::from(&message[..]).freeze();

        let Message::Protocols(protocols) = Message::decode(message).unwrap() else {
            panic!("invalid message type");
        };

        assert_eq!(protocols.len(), 3);
        assert_eq!(
            protocols[0],
            Protocol::try_from(&b"/multistream/1.0.0"[..])
                .expect("valid multitstream-select header")
        );
        assert_eq!(
            protocols[1],
            Protocol::try_from(&b"/13371338/proto/1"[..])
                .expect("valid multitstream-select header")
        );
        assert_eq!(
            protocols[2],
            Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid multitstream-select header")
        );
    }

    /// A response that is a bare header (not a protocol list) is rejected.
    #[test]
    fn register_response_invalid_message() {
        let mut bytes = BytesMut::with_capacity(32);
        let message = Message::Header(HeaderLine::V1);
        let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

        let (mut dialer_state, _message) =
            DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();

        match dialer_state.register_response(bytes.freeze().to_vec()) {
            Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {}
            event => panic!("invalid event: {event:?}"),
        }
    }

    /// A protocol list that does not start with the header line is rejected.
    #[test]
    fn header_line_missing() {
        let mut bytes = BytesMut::with_capacity(256);
        let message = Message::Protocols(vec![
            Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
            Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
        ]);
        let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

        let (mut dialer_state, _message) =
            DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();

        match dialer_state.register_response(bytes.freeze().to_vec()) {
            Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {}
            event => panic!("invalid event: {event:?}"),
        }
    }

    /// The remote confirming the main protocol yields `Succeeded` with that protocol.
    #[test]
    fn negotiate_main_protocol() {
        let message = encode_multistream_message(
            vec![Message::Protocol(
                Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
            )]
            .into_iter(),
        )
        .unwrap()
        .freeze();

        let (mut dialer_state, _message) = DialerState::propose(
            ProtocolName::from("/13371338/proto/1"),
            vec![ProtocolName::from("/sup/proto/1")],
        )
        .unwrap();

        match dialer_state.register_response(message.to_vec()) {
            Ok(HandshakeResult::Succeeded(negotiated)) =>
                assert_eq!(negotiated, ProtocolName::from("/13371338/proto/1")),
            _ => panic!("invalid event"),
        }
    }

    /// The remote confirming a fallback yields `Succeeded` with that fallback.
    #[test]
    fn negotiate_fallback_protocol() {
        let message = encode_multistream_message(
            vec![Message::Protocol(
                Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
            )]
            .into_iter(),
        )
        .unwrap()
        .freeze();

        let (mut dialer_state, _message) = DialerState::propose(
            ProtocolName::from("/13371338/proto/1"),
            vec![ProtocolName::from("/sup/proto/1")],
        )
        .unwrap();

        match dialer_state.register_response(message.to_vec()) {
            Ok(HandshakeResult::Succeeded(negotiated)) =>
                assert_eq!(negotiated, ProtocolName::from("/sup/proto/1")),
            _ => panic!("invalid event"),
        }
    }
}