1use crate::{
25 codec::unsigned_varint::UnsignedVarint,
26 error::{self, Error},
27 multistream_select::{
28 protocol::{
29 encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError,
30 },
31 Negotiated, NegotiationError,
32 },
33 types::protocol::ProtocolName,
34};
35
36use bytes::{Bytes, BytesMut};
37use futures::prelude::*;
38use smallvec::SmallVec;
39use std::{
40 convert::TryFrom as _,
41 iter::FromIterator,
42 mem,
43 pin::Pin,
44 task::{Context, Poll},
45};
46
/// `tracing` target used by the log events emitted from this module.
const LOG_TARGET: &str = "litep2p::multistream-select";
48
49pub fn listener_select_proto<R, I>(inner: R, protocols: I) -> ListenerSelectFuture<R, I::Item>
57where
58 R: AsyncRead + AsyncWrite,
59 I: IntoIterator,
60 I::Item: AsRef<[u8]>,
61{
62 let protocols = protocols.into_iter().filter_map(|n| match Protocol::try_from(n.as_ref()) {
63 Ok(p) => Some((n, p)),
64 Err(e) => {
65 tracing::warn!(
66 target: LOG_TARGET,
67 "Listener: Ignoring invalid protocol: {} due to {}",
68 String::from_utf8_lossy(n.as_ref()),
69 e
70 );
71 None
72 }
73 });
74 ListenerSelectFuture {
75 protocols: SmallVec::from_iter(protocols),
76 state: State::RecvHeader {
77 io: MessageIO::new(inner),
78 },
79 last_sent_na: false,
80 }
81}
82
/// Future that performs the listener side of a multistream-select negotiation.
///
/// Created by `listener_select_proto`. It resolves to the name of the agreed protocol
/// together with the `Negotiated` I/O stream, or to a `NegotiationError`.
#[pin_project::pin_project]
pub struct ListenerSelectFuture<R, N> {
    /// Locally supported protocols as `(name as supplied by the caller, parsed Protocol)`
    /// pairs. Names that failed to parse were already dropped by `listener_select_proto`.
    protocols: SmallVec<[(N, Protocol); 8]>,
    /// Current state of the negotiation state machine.
    state: State<R, N>,
    /// Whether the most recent message sent to the remote was `Message::NotAvailable`.
    /// While set, an invalid message or EOF received from the remote is reported as
    /// `NegotiationError::Failed` instead of the underlying protocol/I/O error.
    last_sent_na: bool,
}
99
/// States of the listener-side negotiation state machine polled by `ListenerSelectFuture`.
enum State<R, N> {
    /// Waiting for the remote's multistream-select header line.
    RecvHeader {
        io: MessageIO<R>,
    },
    /// The remote's header was received; our own `HeaderLine::V1` header is sent next.
    SendHeader {
        io: MessageIO<R>,
    },
    /// Waiting for the remote's next request (`ls` or a protocol proposal).
    RecvMessage {
        io: MessageIO<R>,
    },
    /// A reply is ready to be written. `protocol` is `Some` only when the reply confirms
    /// a supported protocol.
    SendMessage {
        io: MessageIO<R>,
        message: Message,
        protocol: Option<N>,
    },
    /// Flushing written output. With `protocol == Some(..)` negotiation completes once
    /// the flush succeeds; with `None` the machine goes back to `RecvMessage`.
    Flush {
        io: MessageIO<R>,
        protocol: Option<N>,
    },
    /// Placeholder stored while a state is being processed, and the final state after the
    /// future has resolved. Polling in this state panics.
    Done,
}
121
122impl<R, N> Future for ListenerSelectFuture<R, N>
123where
124 R: AsyncRead + AsyncWrite + Unpin,
129 N: AsRef<[u8]> + Clone,
130{
131 type Output = Result<(N, Negotiated<R>), NegotiationError>;
132
133 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
134 let this = self.project();
135
136 loop {
137 match mem::replace(this.state, State::Done) {
138 State::RecvHeader { mut io } => {
139 match io.poll_next_unpin(cx) {
140 Poll::Ready(Some(Ok(Message::Header(h)))) => match h {
141 HeaderLine::V1 => *this.state = State::SendHeader { io },
142 },
143 Poll::Ready(Some(Ok(_))) =>
144 return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
145 Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))),
146 Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
150 Poll::Pending => {
151 *this.state = State::RecvHeader { io };
152 return Poll::Pending;
153 }
154 }
155 }
156
157 State::SendHeader { mut io } => {
158 match Pin::new(&mut io).poll_ready(cx) {
159 Poll::Pending => {
160 *this.state = State::SendHeader { io };
161 return Poll::Pending;
162 }
163 Poll::Ready(Ok(())) => {}
164 Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
165 }
166
167 let msg = Message::Header(HeaderLine::V1);
168 if let Err(err) = Pin::new(&mut io).start_send(msg) {
169 return Poll::Ready(Err(From::from(err)));
170 }
171
172 *this.state = State::Flush { io, protocol: None };
173 }
174
175 State::RecvMessage { mut io } => {
176 let msg = match Pin::new(&mut io).poll_next(cx) {
177 Poll::Ready(Some(Ok(msg))) => msg,
178 Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
188 Poll::Pending => {
189 *this.state = State::RecvMessage { io };
190 return Poll::Pending;
191 }
192 Poll::Ready(Some(Err(err))) => {
193 if *this.last_sent_na {
194 if let ProtocolError::InvalidMessage = &err {
201 tracing::trace!(
202 target: LOG_TARGET,
203 "Listener: Negotiation failed with invalid \
204 message after protocol rejection."
205 );
206 return Poll::Ready(Err(NegotiationError::Failed));
207 }
208 if let ProtocolError::IoError(e) = &err {
209 if e.kind() == std::io::ErrorKind::UnexpectedEof {
210 tracing::trace!(
211 target: LOG_TARGET,
212 "Listener: Negotiation failed with EOF \
213 after protocol rejection."
214 );
215 return Poll::Ready(Err(NegotiationError::Failed));
216 }
217 }
218 }
219
220 return Poll::Ready(Err(From::from(err)));
221 }
222 };
223
224 match msg {
225 Message::ListProtocols => {
226 let supported =
227 this.protocols.iter().map(|(_, p)| p).cloned().collect();
228 let message = Message::Protocols(supported);
229 *this.state = State::SendMessage {
230 io,
231 message,
232 protocol: None,
233 }
234 }
235 Message::Protocol(p) => {
236 let protocol = this.protocols.iter().find_map(|(name, proto)| {
237 if &p == proto {
238 Some(name.clone())
239 } else {
240 None
241 }
242 });
243
244 let message = if protocol.is_some() {
245 tracing::debug!("Listener: confirming protocol: {}", p);
246 Message::Protocol(p.clone())
247 } else {
248 tracing::debug!(
249 "Listener: rejecting protocol: {}",
250 String::from_utf8_lossy(p.as_ref())
251 );
252 Message::NotAvailable
253 };
254
255 *this.state = State::SendMessage {
256 io,
257 message,
258 protocol,
259 };
260 }
261 _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
262 }
263 }
264
265 State::SendMessage {
266 mut io,
267 message,
268 protocol,
269 } => {
270 match Pin::new(&mut io).poll_ready(cx) {
271 Poll::Pending => {
272 *this.state = State::SendMessage {
273 io,
274 message,
275 protocol,
276 };
277 return Poll::Pending;
278 }
279 Poll::Ready(Ok(())) => {}
280 Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
281 }
282
283 if let Message::NotAvailable = &message {
284 *this.last_sent_na = true;
285 } else {
286 *this.last_sent_na = false;
287 }
288
289 if let Err(err) = Pin::new(&mut io).start_send(message) {
290 return Poll::Ready(Err(From::from(err)));
291 }
292
293 *this.state = State::Flush { io, protocol };
294 }
295
296 State::Flush { mut io, protocol } => {
297 match Pin::new(&mut io).poll_flush(cx) {
298 Poll::Pending => {
299 *this.state = State::Flush { io, protocol };
300 return Poll::Pending;
301 }
302 Poll::Ready(Ok(())) => {
303 match protocol {
306 Some(protocol) => {
307 tracing::debug!(
308 "Listener: sent confirmed protocol: {}",
309 String::from_utf8_lossy(protocol.as_ref())
310 );
311 let io = Negotiated::completed(io.into_inner());
312 return Poll::Ready(Ok((protocol, io)));
313 }
314 None => *this.state = State::RecvMessage { io },
315 }
316 }
317 Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
318 }
319 }
320
321 State::Done => panic!("State::poll called after completion"),
322 }
323 }
324 }
325}
326
/// Outcome of `listener_negotiate`.
#[derive(Debug)]
pub enum ListenerSelectResult {
    /// One of the remote's proposed protocols is supported locally.
    Accepted {
        /// The locally supported protocol that matched the remote's proposal.
        protocol: ProtocolName,

        /// Encoded multistream-select confirmation (`Message::Protocol`) for the accepted
        /// protocol, produced by `encode_multistream_message`. Presumably the caller
        /// writes it back to the remote — confirm at call sites.
        message: BytesMut,
    },

    /// None of the remote's proposed protocols is supported locally.
    Rejected {
        /// Encoded `na` (`Message::NotAvailable`) response, produced by
        /// `encode_multistream_message`. Presumably the caller writes it back to the remote.
        message: BytesMut,
    },
}
345
346pub fn listener_negotiate<'a>(
352 supported_protocols: &'a mut impl Iterator<Item = &'a ProtocolName>,
353 payload: Bytes,
354) -> crate::Result<ListenerSelectResult> {
355 let Message::Protocols(protocols) = Message::decode(payload).map_err(|_| Error::InvalidData)?
356 else {
357 return Err(Error::NegotiationError(
358 error::NegotiationError::MultistreamSelectError(NegotiationError::Failed),
359 ));
360 };
361
362 let mut protocol_iter = protocols.into_iter();
365 let header =
366 Protocol::try_from(&b"/multistream/1.0.0"[..]).expect("valid multitstream-select header");
367
368 if protocol_iter.next() != Some(header) {
369 return Err(Error::NegotiationError(
370 error::NegotiationError::MultistreamSelectError(NegotiationError::Failed),
371 ));
372 }
373
374 for protocol in protocol_iter {
375 tracing::trace!(
376 target: LOG_TARGET,
377 protocol = ?std::str::from_utf8(protocol.as_ref()),
378 "listener: checking protocol",
379 );
380
381 for supported in &mut *supported_protocols {
382 if protocol.as_ref() == supported.as_bytes() {
383 return Ok(ListenerSelectResult::Accepted {
384 protocol: supported.clone(),
385 message: encode_multistream_message(std::iter::once(Message::Protocol(
386 protocol,
387 )))?,
388 });
389 }
390 }
391 }
392
393 tracing::trace!(
394 target: LOG_TARGET,
395 "listener: handshake rejected, no supported protocol found",
396 );
397
398 Ok(ListenerSelectResult::Rejected {
399 message: encode_multistream_message(std::iter::once(Message::NotAvailable))?,
400 })
401}
402
403#[cfg(test)]
404mod tests {
405 use super::*;
406
407 #[test]
408 fn listener_negotiate_works() {
409 let mut local_protocols = vec![
410 ProtocolName::from("/13371338/proto/1"),
411 ProtocolName::from("/sup/proto/1"),
412 ProtocolName::from("/13371338/proto/2"),
413 ProtocolName::from("/13371338/proto/3"),
414 ProtocolName::from("/13371338/proto/4"),
415 ];
416 let message = encode_multistream_message(
417 vec![
418 Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()),
419 Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()),
420 ]
421 .into_iter(),
422 )
423 .unwrap()
424 .freeze();
425
426 match listener_negotiate(&mut local_protocols.iter(), message) {
427 Err(error) => panic!("error received: {error:?}"),
428 Ok(ListenerSelectResult::Rejected { .. }) => panic!("message rejected"),
429 Ok(ListenerSelectResult::Accepted { protocol, message }) => {
430 assert_eq!(protocol, ProtocolName::from("/13371338/proto/1"));
431 }
432 }
433 }
434
435 #[test]
436 fn invalid_message() {
437 let mut local_protocols = vec![
438 ProtocolName::from("/13371338/proto/1"),
439 ProtocolName::from("/sup/proto/1"),
440 ProtocolName::from("/13371338/proto/2"),
441 ProtocolName::from("/13371338/proto/3"),
442 ProtocolName::from("/13371338/proto/4"),
443 ];
444 let message = encode_multistream_message(std::iter::once(Message::Protocols(vec![
445 Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
446 Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
447 ])))
448 .unwrap()
449 .freeze();
450
451 match listener_negotiate(&mut local_protocols.iter(), message) {
452 Err(error) => assert!(std::matches!(error, Error::InvalidData)),
453 _ => panic!("invalid event"),
454 }
455 }
456
457 #[test]
458 fn only_header_line_received() {
459 let mut local_protocols = vec![
460 ProtocolName::from("/13371338/proto/1"),
461 ProtocolName::from("/sup/proto/1"),
462 ProtocolName::from("/13371338/proto/2"),
463 ProtocolName::from("/13371338/proto/3"),
464 ProtocolName::from("/13371338/proto/4"),
465 ];
466
467 let mut bytes = BytesMut::with_capacity(32);
469 let message = Message::Header(HeaderLine::V1);
470 let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();
471
472 match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
473 Err(error) => assert!(std::matches!(
474 error,
475 Error::NegotiationError(error::NegotiationError::MultistreamSelectError(
476 NegotiationError::Failed
477 ))
478 )),
479 event => panic!("invalid event: {event:?}"),
480 }
481 }
482
483 #[test]
484 fn header_line_missing() {
485 let mut local_protocols = vec![
486 ProtocolName::from("/13371338/proto/1"),
487 ProtocolName::from("/sup/proto/1"),
488 ProtocolName::from("/13371338/proto/2"),
489 ProtocolName::from("/13371338/proto/3"),
490 ProtocolName::from("/13371338/proto/4"),
491 ];
492
493 let mut bytes = BytesMut::with_capacity(256);
495 let message = Message::Protocols(vec![
496 Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
497 Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
498 ]);
499 let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();
500
501 match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
502 Err(error) => assert!(std::matches!(
503 error,
504 Error::NegotiationError(error::NegotiationError::MultistreamSelectError(
505 NegotiationError::Failed
506 ))
507 )),
508 event => panic!("invalid event: {event:?}"),
509 }
510 }
511
512 #[test]
513 fn protocol_not_supported() {
514 let mut local_protocols = vec![
515 ProtocolName::from("/13371338/proto/1"),
516 ProtocolName::from("/sup/proto/1"),
517 ProtocolName::from("/13371338/proto/2"),
518 ProtocolName::from("/13371338/proto/3"),
519 ProtocolName::from("/13371338/proto/4"),
520 ];
521 let message = encode_multistream_message(
522 vec![Message::Protocol(
523 Protocol::try_from(&b"/13371339/proto/1"[..]).unwrap(),
524 )]
525 .into_iter(),
526 )
527 .unwrap()
528 .freeze();
529
530 match listener_negotiate(&mut local_protocols.iter(), message) {
531 Err(error) => panic!("error received: {error:?}"),
532 Ok(ListenerSelectResult::Rejected { message }) => {
533 assert_eq!(
534 message,
535 encode_multistream_message(std::iter::once(Message::NotAvailable)).unwrap()
536 );
537 }
538 Ok(ListenerSelectResult::Accepted { protocol, message }) => panic!("message accepted"),
539 }
540 }
541}