1use crate::behaviour::Mode;
22use crate::protocol::{
23 KadInStreamSink, KadOutStreamSink, KadPeer, KadRequestMsg, KadResponseMsg, ProtocolConfig,
24};
25use crate::record_priv::{self, Record};
26use crate::QueryId;
27use either::Either;
28use futures::prelude::*;
29use futures::stream::SelectAll;
30use instant::Instant;
31use libp2p_core::{upgrade, ConnectedPoint};
32use libp2p_identity::PeerId;
33use libp2p_swarm::handler::{
34 ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound,
35};
36use libp2p_swarm::{
37 ConnectionHandler, ConnectionHandlerEvent, ConnectionId, KeepAlive, Stream, StreamUpgradeError,
38 SubstreamProtocol, SupportedProtocols,
39};
40use log::trace;
41use std::collections::VecDeque;
42use std::task::Waker;
43use std::{
44 error, fmt, io, marker::PhantomData, pin::Pin, task::Context, task::Poll, time::Duration,
45};
46
/// Upper bound on the number of Kademlia substreams tracked per connection.
///
/// The limit applies separately to each direction: outbound substreams (open plus
/// requested) and inbound substreams.
const MAX_NUM_SUBSTREAMS: usize = 32;
48
49pub struct Handler {
57 protocol_config: ProtocolConfig,
59
60 mode: Mode,
62
63 idle_timeout: Duration,
65
66 next_connec_unique_id: UniqueConnecId,
68
69 outbound_substreams: SelectAll<OutboundSubstreamState>,
71
72 num_requested_outbound_streams: usize,
74
75 pending_messages: VecDeque<(KadRequestMsg, Option<QueryId>)>,
78
79 inbound_substreams: SelectAll<InboundSubstreamState>,
81
82 keep_alive: KeepAlive,
84
85 endpoint: ConnectedPoint,
88
89 remote_peer_id: PeerId,
91
92 protocol_status: Option<ProtocolStatus>,
94
95 remote_supported_protocols: SupportedProtocols,
96
97 connection_id: ConnectionId,
99}
100
/// What is known about the remote's support for our Kademlia protocol.
#[derive(Clone, Copy, Debug, PartialEq)]
struct ProtocolStatus {
    /// Whether the remote supports our protocol.
    supported: bool,
    /// Whether `supported` has already been reported to the behaviour.
    reported: bool,
}
110
111enum OutboundSubstreamState {
113 PendingSend(KadOutStreamSink<Stream>, KadRequestMsg, Option<QueryId>),
115 PendingFlush(KadOutStreamSink<Stream>, Option<QueryId>),
117 WaitingAnswer(KadOutStreamSink<Stream>, QueryId),
120 ReportError(HandlerQueryErr, QueryId),
122 Closing(KadOutStreamSink<Stream>),
124 Done,
126 Poisoned,
127}
128
129enum InboundSubstreamState {
131 WaitingMessage {
133 first: bool,
135 connection_id: UniqueConnecId,
136 substream: KadInStreamSink<Stream>,
137 },
138 WaitingBehaviour(UniqueConnecId, KadInStreamSink<Stream>, Option<Waker>),
140 PendingSend(UniqueConnecId, KadInStreamSink<Stream>, KadResponseMsg),
142 PendingFlush(UniqueConnecId, KadInStreamSink<Stream>),
144 Closing(KadInStreamSink<Stream>),
146 Cancelled,
148
149 Poisoned {
150 phantom: PhantomData<QueryId>,
151 },
152}
153
154impl InboundSubstreamState {
155 fn try_answer_with(
156 &mut self,
157 id: RequestId,
158 msg: KadResponseMsg,
159 ) -> Result<(), KadResponseMsg> {
160 match std::mem::replace(
161 self,
162 InboundSubstreamState::Poisoned {
163 phantom: PhantomData,
164 },
165 ) {
166 InboundSubstreamState::WaitingBehaviour(conn_id, substream, mut waker)
167 if conn_id == id.connec_unique_id =>
168 {
169 *self = InboundSubstreamState::PendingSend(conn_id, substream, msg);
170
171 if let Some(waker) = waker.take() {
172 waker.wake();
173 }
174
175 Ok(())
176 }
177 other => {
178 *self = other;
179
180 Err(msg)
181 }
182 }
183 }
184
185 fn close(&mut self) {
186 match std::mem::replace(
187 self,
188 InboundSubstreamState::Poisoned {
189 phantom: PhantomData,
190 },
191 ) {
192 InboundSubstreamState::WaitingMessage { substream, .. }
193 | InboundSubstreamState::WaitingBehaviour(_, substream, _)
194 | InboundSubstreamState::PendingSend(_, substream, _)
195 | InboundSubstreamState::PendingFlush(_, substream)
196 | InboundSubstreamState::Closing(substream) => {
197 *self = InboundSubstreamState::Closing(substream);
198 }
199 InboundSubstreamState::Cancelled => {
200 *self = InboundSubstreamState::Cancelled;
201 }
202 InboundSubstreamState::Poisoned { .. } => unreachable!(),
203 }
204 }
205}
206
207#[derive(Debug)]
209pub enum HandlerEvent {
210 ProtocolConfirmed { endpoint: ConnectedPoint },
213 ProtocolNotSupported { endpoint: ConnectedPoint },
216
217 FindNodeReq {
220 key: Vec<u8>,
222 request_id: RequestId,
224 },
225
226 FindNodeRes {
228 closer_peers: Vec<KadPeer>,
230 query_id: QueryId,
232 },
233
234 GetProvidersReq {
237 key: record_priv::Key,
239 request_id: RequestId,
241 },
242
243 GetProvidersRes {
245 closer_peers: Vec<KadPeer>,
247 provider_peers: Vec<KadPeer>,
249 query_id: QueryId,
251 },
252
253 QueryError {
255 error: HandlerQueryErr,
257 query_id: QueryId,
259 },
260
261 AddProvider {
263 key: record_priv::Key,
265 provider: KadPeer,
267 },
268
269 GetRecord {
271 key: record_priv::Key,
273 request_id: RequestId,
275 },
276
277 GetRecordRes {
279 record: Option<Record>,
281 closer_peers: Vec<KadPeer>,
283 query_id: QueryId,
285 },
286
287 PutRecord {
289 record: Record,
290 request_id: RequestId,
292 },
293
294 PutRecordRes {
296 key: record_priv::Key,
298 value: Vec<u8>,
300 query_id: QueryId,
302 },
303}
304
305#[derive(Debug)]
307pub enum HandlerQueryErr {
308 Upgrade(StreamUpgradeError<io::Error>),
310 UnexpectedMessage,
312 Io(io::Error),
314}
315
316impl fmt::Display for HandlerQueryErr {
317 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318 match self {
319 HandlerQueryErr::Upgrade(err) => {
320 write!(f, "Error while performing Kademlia query: {err}")
321 }
322 HandlerQueryErr::UnexpectedMessage => {
323 write!(
324 f,
325 "Remote answered our Kademlia RPC query with the wrong message type"
326 )
327 }
328 HandlerQueryErr::Io(err) => {
329 write!(f, "I/O error during a Kademlia RPC query: {err}")
330 }
331 }
332 }
333}
334
335impl error::Error for HandlerQueryErr {
336 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
337 match self {
338 HandlerQueryErr::Upgrade(err) => Some(err),
339 HandlerQueryErr::UnexpectedMessage => None,
340 HandlerQueryErr::Io(err) => Some(err),
341 }
342 }
343}
344
345impl From<StreamUpgradeError<io::Error>> for HandlerQueryErr {
346 fn from(err: StreamUpgradeError<io::Error>) -> Self {
347 HandlerQueryErr::Upgrade(err)
348 }
349}
350
351#[derive(Debug)]
353pub enum HandlerIn {
354 Reset(RequestId),
362
363 ReconfigureMode { new_mode: Mode },
365
366 FindNodeReq {
369 key: Vec<u8>,
371 query_id: QueryId,
373 },
374
375 FindNodeRes {
377 closer_peers: Vec<KadPeer>,
379 request_id: RequestId,
383 },
384
385 GetProvidersReq {
388 key: record_priv::Key,
390 query_id: QueryId,
392 },
393
394 GetProvidersRes {
396 closer_peers: Vec<KadPeer>,
398 provider_peers: Vec<KadPeer>,
400 request_id: RequestId,
404 },
405
406 AddProvider {
411 key: record_priv::Key,
413 provider: KadPeer,
415 },
416
417 GetRecord {
419 key: record_priv::Key,
421 query_id: QueryId,
423 },
424
425 GetRecordRes {
427 record: Option<Record>,
429 closer_peers: Vec<KadPeer>,
431 request_id: RequestId,
433 },
434
435 PutRecord {
437 record: Record,
438 query_id: QueryId,
440 },
441
442 PutRecordRes {
444 key: record_priv::Key,
446 value: Vec<u8>,
448 request_id: RequestId,
450 },
451}
452
453#[derive(Debug, PartialEq, Eq, Copy, Clone)]
456pub struct RequestId {
457 connec_unique_id: UniqueConnecId,
459}
460
/// Per-connection counter value identifying one inbound substream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct UniqueConnecId(u64);
464
465impl Handler {
466 pub fn new(
467 protocol_config: ProtocolConfig,
468 idle_timeout: Duration,
469 endpoint: ConnectedPoint,
470 remote_peer_id: PeerId,
471 mode: Mode,
472 connection_id: ConnectionId,
473 ) -> Self {
474 match &endpoint {
475 ConnectedPoint::Dialer { .. } => {
476 log::debug!(
477 "Operating in {mode}-mode on new outbound connection to {remote_peer_id}"
478 );
479 }
480 ConnectedPoint::Listener { .. } => {
481 log::debug!(
482 "Operating in {mode}-mode on new inbound connection to {remote_peer_id}"
483 );
484 }
485 }
486
487 #[allow(deprecated)]
488 let keep_alive = KeepAlive::Until(Instant::now() + idle_timeout);
489
490 Handler {
491 protocol_config,
492 mode,
493 idle_timeout,
494 endpoint,
495 remote_peer_id,
496 next_connec_unique_id: UniqueConnecId(0),
497 inbound_substreams: Default::default(),
498 outbound_substreams: Default::default(),
499 num_requested_outbound_streams: 0,
500 pending_messages: Default::default(),
501 keep_alive,
502 protocol_status: None,
503 remote_supported_protocols: Default::default(),
504 connection_id,
505 }
506 }
507
508 fn on_fully_negotiated_outbound(
509 &mut self,
510 FullyNegotiatedOutbound { protocol, info: () }: FullyNegotiatedOutbound<
511 <Self as ConnectionHandler>::OutboundProtocol,
512 <Self as ConnectionHandler>::OutboundOpenInfo,
513 >,
514 ) {
515 if let Some((msg, query_id)) = self.pending_messages.pop_front() {
516 self.outbound_substreams
517 .push(OutboundSubstreamState::PendingSend(protocol, msg, query_id));
518 } else {
519 debug_assert!(false, "Requested outbound stream without message")
520 }
521
522 self.num_requested_outbound_streams -= 1;
523
524 if self.protocol_status.is_none() {
525 self.protocol_status = Some(ProtocolStatus {
529 supported: true,
530 reported: false,
531 });
532 }
533 }
534
535 fn on_fully_negotiated_inbound(
536 &mut self,
537 FullyNegotiatedInbound { protocol, .. }: FullyNegotiatedInbound<
538 <Self as ConnectionHandler>::InboundProtocol,
539 <Self as ConnectionHandler>::InboundOpenInfo,
540 >,
541 ) {
542 let protocol = match protocol {
545 future::Either::Left(p) => p,
546 future::Either::Right(p) => void::unreachable(p),
547 };
548
549 if self.protocol_status.is_none() {
550 self.protocol_status = Some(ProtocolStatus {
554 supported: true,
555 reported: false,
556 });
557 }
558
559 if self.inbound_substreams.len() == MAX_NUM_SUBSTREAMS {
560 if let Some(s) = self.inbound_substreams.iter_mut().find(|s| {
561 matches!(
562 s,
563 InboundSubstreamState::WaitingMessage { first: false, .. }
565 )
566 }) {
567 *s = InboundSubstreamState::Cancelled;
568 log::debug!(
569 "New inbound substream to {:?} exceeds inbound substream limit. \
570 Removed older substream waiting to be reused.",
571 self.remote_peer_id,
572 )
573 } else {
574 log::warn!(
575 "New inbound substream to {:?} exceeds inbound substream limit. \
576 No older substream waiting to be reused. Dropping new substream.",
577 self.remote_peer_id,
578 );
579 return;
580 }
581 }
582
583 let connec_unique_id = self.next_connec_unique_id;
584 self.next_connec_unique_id.0 += 1;
585 self.inbound_substreams
586 .push(InboundSubstreamState::WaitingMessage {
587 first: true,
588 connection_id: connec_unique_id,
589 substream: protocol,
590 });
591 }
592
593 fn on_dial_upgrade_error(
594 &mut self,
595 DialUpgradeError {
596 info: (), error, ..
597 }: DialUpgradeError<
598 <Self as ConnectionHandler>::OutboundOpenInfo,
599 <Self as ConnectionHandler>::OutboundProtocol,
600 >,
601 ) {
602 if let Some((_, Some(query_id))) = self.pending_messages.pop_front() {
606 self.outbound_substreams
607 .push(OutboundSubstreamState::ReportError(error.into(), query_id));
608 }
609
610 self.num_requested_outbound_streams -= 1;
611 }
612}
613
614impl ConnectionHandler for Handler {
615 type FromBehaviour = HandlerIn;
616 type ToBehaviour = HandlerEvent;
617 type Error = io::Error; type InboundProtocol = Either<ProtocolConfig, upgrade::DeniedUpgrade>;
619 type OutboundProtocol = ProtocolConfig;
620 type OutboundOpenInfo = ();
621 type InboundOpenInfo = ();
622
623 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
624 match self.mode {
625 Mode::Server => SubstreamProtocol::new(Either::Left(self.protocol_config.clone()), ()),
626 Mode::Client => SubstreamProtocol::new(Either::Right(upgrade::DeniedUpgrade), ()),
627 }
628 }
629
630 fn on_behaviour_event(&mut self, message: HandlerIn) {
631 match message {
632 HandlerIn::Reset(request_id) => {
633 if let Some(state) = self
634 .inbound_substreams
635 .iter_mut()
636 .find(|state| match state {
637 InboundSubstreamState::WaitingBehaviour(conn_id, _, _) => {
638 conn_id == &request_id.connec_unique_id
639 }
640 _ => false,
641 })
642 {
643 state.close();
644 }
645 }
646 HandlerIn::FindNodeReq { key, query_id } => {
647 let msg = KadRequestMsg::FindNode { key };
648 self.pending_messages.push_back((msg, Some(query_id)));
649 }
650 HandlerIn::FindNodeRes {
651 closer_peers,
652 request_id,
653 } => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }),
654 HandlerIn::GetProvidersReq { key, query_id } => {
655 let msg = KadRequestMsg::GetProviders { key };
656 self.pending_messages.push_back((msg, Some(query_id)));
657 }
658 HandlerIn::GetProvidersRes {
659 closer_peers,
660 provider_peers,
661 request_id,
662 } => self.answer_pending_request(
663 request_id,
664 KadResponseMsg::GetProviders {
665 closer_peers,
666 provider_peers,
667 },
668 ),
669 HandlerIn::AddProvider { key, provider } => {
670 let msg = KadRequestMsg::AddProvider { key, provider };
671 self.pending_messages.push_back((msg, None));
672 }
673 HandlerIn::GetRecord { key, query_id } => {
674 let msg = KadRequestMsg::GetValue { key };
675 self.pending_messages.push_back((msg, Some(query_id)));
676 }
677 HandlerIn::PutRecord { record, query_id } => {
678 let msg = KadRequestMsg::PutValue { record };
679 self.pending_messages.push_back((msg, Some(query_id)));
680 }
681 HandlerIn::GetRecordRes {
682 record,
683 closer_peers,
684 request_id,
685 } => {
686 self.answer_pending_request(
687 request_id,
688 KadResponseMsg::GetValue {
689 record,
690 closer_peers,
691 },
692 );
693 }
694 HandlerIn::PutRecordRes {
695 key,
696 request_id,
697 value,
698 } => {
699 self.answer_pending_request(request_id, KadResponseMsg::PutValue { key, value });
700 }
701 HandlerIn::ReconfigureMode { new_mode } => {
702 let peer = self.remote_peer_id;
703
704 match &self.endpoint {
705 ConnectedPoint::Dialer { .. } => {
706 log::debug!(
707 "Now operating in {new_mode}-mode on outbound connection with {peer}"
708 )
709 }
710 ConnectedPoint::Listener { local_addr, .. } => {
711 log::debug!("Now operating in {new_mode}-mode on inbound connection with {peer} assuming that one of our external addresses routes to {local_addr}")
712 }
713 }
714
715 self.mode = new_mode;
716 }
717 }
718 }
719
720 fn connection_keep_alive(&self) -> KeepAlive {
721 self.keep_alive
722 }
723
724 fn poll(
725 &mut self,
726 cx: &mut Context<'_>,
727 ) -> Poll<
728 ConnectionHandlerEvent<
729 Self::OutboundProtocol,
730 Self::OutboundOpenInfo,
731 Self::ToBehaviour,
732 Self::Error,
733 >,
734 > {
735 match &mut self.protocol_status {
736 Some(status) if !status.reported => {
737 status.reported = true;
738 let event = if status.supported {
739 HandlerEvent::ProtocolConfirmed {
740 endpoint: self.endpoint.clone(),
741 }
742 } else {
743 HandlerEvent::ProtocolNotSupported {
744 endpoint: self.endpoint.clone(),
745 }
746 };
747
748 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
749 }
750 _ => {}
751 }
752
753 if let Poll::Ready(Some(event)) = self.outbound_substreams.poll_next_unpin(cx) {
754 return Poll::Ready(event);
755 }
756
757 if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) {
758 return Poll::Ready(event);
759 }
760
761 let num_in_progress_outbound_substreams =
762 self.outbound_substreams.len() + self.num_requested_outbound_streams;
763 if num_in_progress_outbound_substreams < MAX_NUM_SUBSTREAMS
764 && self.num_requested_outbound_streams < self.pending_messages.len()
765 {
766 self.num_requested_outbound_streams += 1;
767 return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
768 protocol: SubstreamProtocol::new(self.protocol_config.clone(), ()),
769 });
770 }
771
772 let no_streams = self.outbound_substreams.is_empty() && self.inbound_substreams.is_empty();
773
774 self.keep_alive = {
775 #[allow(deprecated)]
776 match (no_streams, self.keep_alive) {
777 (true, k @ KeepAlive::Until(_)) => k,
779 (true, _) => KeepAlive::Until(Instant::now() + self.idle_timeout),
781 (false, _) => KeepAlive::Yes,
783 }
784 };
785
786 Poll::Pending
787 }
788
789 fn on_connection_event(
790 &mut self,
791 event: ConnectionEvent<
792 Self::InboundProtocol,
793 Self::OutboundProtocol,
794 Self::InboundOpenInfo,
795 Self::OutboundOpenInfo,
796 >,
797 ) {
798 match event {
799 ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
800 self.on_fully_negotiated_outbound(fully_negotiated_outbound)
801 }
802 ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
803 self.on_fully_negotiated_inbound(fully_negotiated_inbound)
804 }
805 ConnectionEvent::DialUpgradeError(dial_upgrade_error) => {
806 self.on_dial_upgrade_error(dial_upgrade_error)
807 }
808 ConnectionEvent::AddressChange(_)
809 | ConnectionEvent::ListenUpgradeError(_)
810 | ConnectionEvent::LocalProtocolsChange(_) => {}
811 ConnectionEvent::RemoteProtocolsChange(change) => {
812 let dirty = self.remote_supported_protocols.on_protocols_change(change);
813
814 if dirty {
815 let remote_supports_our_kademlia_protocols = self
816 .remote_supported_protocols
817 .iter()
818 .any(|p| self.protocol_config.protocol_names().contains(p));
819
820 self.protocol_status = Some(compute_new_protocol_status(
821 remote_supports_our_kademlia_protocols,
822 self.protocol_status,
823 self.remote_peer_id,
824 self.connection_id,
825 ))
826 }
827 }
828 }
829 }
830}
831
832fn compute_new_protocol_status(
833 now_supported: bool,
834 current_status: Option<ProtocolStatus>,
835 remote_peer_id: PeerId,
836 connection_id: ConnectionId,
837) -> ProtocolStatus {
838 let current_status = match current_status {
839 None => {
840 return ProtocolStatus {
841 supported: now_supported,
842 reported: false,
843 }
844 }
845 Some(current) => current,
846 };
847
848 if now_supported == current_status.supported {
849 return ProtocolStatus {
850 supported: now_supported,
851 reported: true,
852 };
853 }
854
855 if now_supported {
856 log::debug!("Remote {remote_peer_id} now supports our kademlia protocol on connection {connection_id}");
857 } else {
858 log::debug!("Remote {remote_peer_id} no longer supports our kademlia protocol on connection {connection_id}");
859 }
860
861 ProtocolStatus {
862 supported: now_supported,
863 reported: false,
864 }
865}
866
867impl Handler {
868 fn answer_pending_request(&mut self, request_id: RequestId, mut msg: KadResponseMsg) {
869 for state in self.inbound_substreams.iter_mut() {
870 match state.try_answer_with(request_id, msg) {
871 Ok(()) => return,
872 Err(m) => {
873 msg = m;
874 }
875 }
876 }
877
878 debug_assert!(false, "Cannot find inbound substream for {request_id:?}")
879 }
880}
881
882impl futures::Stream for OutboundSubstreamState {
883 type Item = ConnectionHandlerEvent<ProtocolConfig, (), HandlerEvent, io::Error>;
884
885 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
886 let this = self.get_mut();
887
888 loop {
889 match std::mem::replace(this, OutboundSubstreamState::Poisoned) {
890 OutboundSubstreamState::PendingSend(mut substream, msg, query_id) => {
891 match substream.poll_ready_unpin(cx) {
892 Poll::Ready(Ok(())) => match substream.start_send_unpin(msg) {
893 Ok(()) => {
894 *this = OutboundSubstreamState::PendingFlush(substream, query_id);
895 }
896 Err(error) => {
897 *this = OutboundSubstreamState::Done;
898 let event = query_id.map(|query_id| {
899 ConnectionHandlerEvent::NotifyBehaviour(
900 HandlerEvent::QueryError {
901 error: HandlerQueryErr::Io(error),
902 query_id,
903 },
904 )
905 });
906
907 return Poll::Ready(event);
908 }
909 },
910 Poll::Pending => {
911 *this = OutboundSubstreamState::PendingSend(substream, msg, query_id);
912 return Poll::Pending;
913 }
914 Poll::Ready(Err(error)) => {
915 *this = OutboundSubstreamState::Done;
916 let event = query_id.map(|query_id| {
917 ConnectionHandlerEvent::NotifyBehaviour(HandlerEvent::QueryError {
918 error: HandlerQueryErr::Io(error),
919 query_id,
920 })
921 });
922
923 return Poll::Ready(event);
924 }
925 }
926 }
927 OutboundSubstreamState::PendingFlush(mut substream, query_id) => {
928 match substream.poll_flush_unpin(cx) {
929 Poll::Ready(Ok(())) => {
930 if let Some(query_id) = query_id {
931 *this = OutboundSubstreamState::WaitingAnswer(substream, query_id);
932 } else {
933 *this = OutboundSubstreamState::Closing(substream);
934 }
935 }
936 Poll::Pending => {
937 *this = OutboundSubstreamState::PendingFlush(substream, query_id);
938 return Poll::Pending;
939 }
940 Poll::Ready(Err(error)) => {
941 *this = OutboundSubstreamState::Done;
942 let event = query_id.map(|query_id| {
943 ConnectionHandlerEvent::NotifyBehaviour(HandlerEvent::QueryError {
944 error: HandlerQueryErr::Io(error),
945 query_id,
946 })
947 });
948
949 return Poll::Ready(event);
950 }
951 }
952 }
953 OutboundSubstreamState::WaitingAnswer(mut substream, query_id) => {
954 match substream.poll_next_unpin(cx) {
955 Poll::Ready(Some(Ok(msg))) => {
956 *this = OutboundSubstreamState::Closing(substream);
957 let event = process_kad_response(msg, query_id);
958
959 return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
960 event,
961 )));
962 }
963 Poll::Pending => {
964 *this = OutboundSubstreamState::WaitingAnswer(substream, query_id);
965 return Poll::Pending;
966 }
967 Poll::Ready(Some(Err(error))) => {
968 *this = OutboundSubstreamState::Done;
969 let event = HandlerEvent::QueryError {
970 error: HandlerQueryErr::Io(error),
971 query_id,
972 };
973
974 return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
975 event,
976 )));
977 }
978 Poll::Ready(None) => {
979 *this = OutboundSubstreamState::Done;
980 let event = HandlerEvent::QueryError {
981 error: HandlerQueryErr::Io(io::ErrorKind::UnexpectedEof.into()),
982 query_id,
983 };
984
985 return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
986 event,
987 )));
988 }
989 }
990 }
991 OutboundSubstreamState::ReportError(error, query_id) => {
992 *this = OutboundSubstreamState::Done;
993 let event = HandlerEvent::QueryError { error, query_id };
994
995 return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(event)));
996 }
997 OutboundSubstreamState::Closing(mut stream) => match stream.poll_close_unpin(cx) {
998 Poll::Ready(Ok(())) | Poll::Ready(Err(_)) => return Poll::Ready(None),
999 Poll::Pending => {
1000 *this = OutboundSubstreamState::Closing(stream);
1001 return Poll::Pending;
1002 }
1003 },
1004 OutboundSubstreamState::Done => {
1005 *this = OutboundSubstreamState::Done;
1006 return Poll::Ready(None);
1007 }
1008 OutboundSubstreamState::Poisoned => unreachable!(),
1009 }
1010 }
1011 }
1012}
1013
1014impl futures::Stream for InboundSubstreamState {
1015 type Item = ConnectionHandlerEvent<ProtocolConfig, (), HandlerEvent, io::Error>;
1016
1017 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1018 let this = self.get_mut();
1019
1020 loop {
1021 match std::mem::replace(
1022 this,
1023 Self::Poisoned {
1024 phantom: PhantomData,
1025 },
1026 ) {
1027 InboundSubstreamState::WaitingMessage {
1028 first,
1029 connection_id,
1030 mut substream,
1031 } => match substream.poll_next_unpin(cx) {
1032 Poll::Ready(Some(Ok(KadRequestMsg::Ping))) => {
1033 log::warn!("Kademlia PING messages are unsupported");
1034
1035 *this = InboundSubstreamState::Closing(substream);
1036 }
1037 Poll::Ready(Some(Ok(KadRequestMsg::FindNode { key }))) => {
1038 *this =
1039 InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
1040 return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
1041 HandlerEvent::FindNodeReq {
1042 key,
1043 request_id: RequestId {
1044 connec_unique_id: connection_id,
1045 },
1046 },
1047 )));
1048 }
1049 Poll::Ready(Some(Ok(KadRequestMsg::GetProviders { key }))) => {
1050 *this =
1051 InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
1052 return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
1053 HandlerEvent::GetProvidersReq {
1054 key,
1055 request_id: RequestId {
1056 connec_unique_id: connection_id,
1057 },
1058 },
1059 )));
1060 }
1061 Poll::Ready(Some(Ok(KadRequestMsg::AddProvider { key, provider }))) => {
1062 *this = InboundSubstreamState::WaitingMessage {
1063 first: false,
1064 connection_id,
1065 substream,
1066 };
1067 return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
1068 HandlerEvent::AddProvider { key, provider },
1069 )));
1070 }
1071 Poll::Ready(Some(Ok(KadRequestMsg::GetValue { key }))) => {
1072 *this =
1073 InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
1074 return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
1075 HandlerEvent::GetRecord {
1076 key,
1077 request_id: RequestId {
1078 connec_unique_id: connection_id,
1079 },
1080 },
1081 )));
1082 }
1083 Poll::Ready(Some(Ok(KadRequestMsg::PutValue { record }))) => {
1084 *this =
1085 InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
1086 return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
1087 HandlerEvent::PutRecord {
1088 record,
1089 request_id: RequestId {
1090 connec_unique_id: connection_id,
1091 },
1092 },
1093 )));
1094 }
1095 Poll::Pending => {
1096 *this = InboundSubstreamState::WaitingMessage {
1097 first,
1098 connection_id,
1099 substream,
1100 };
1101 return Poll::Pending;
1102 }
1103 Poll::Ready(None) => {
1104 return Poll::Ready(None);
1105 }
1106 Poll::Ready(Some(Err(e))) => {
1107 trace!("Inbound substream error: {:?}", e);
1108 return Poll::Ready(None);
1109 }
1110 },
1111 InboundSubstreamState::WaitingBehaviour(id, substream, _) => {
1112 *this = InboundSubstreamState::WaitingBehaviour(
1113 id,
1114 substream,
1115 Some(cx.waker().clone()),
1116 );
1117
1118 return Poll::Pending;
1119 }
1120 InboundSubstreamState::PendingSend(id, mut substream, msg) => {
1121 match substream.poll_ready_unpin(cx) {
1122 Poll::Ready(Ok(())) => match substream.start_send_unpin(msg) {
1123 Ok(()) => {
1124 *this = InboundSubstreamState::PendingFlush(id, substream);
1125 }
1126 Err(_) => return Poll::Ready(None),
1127 },
1128 Poll::Pending => {
1129 *this = InboundSubstreamState::PendingSend(id, substream, msg);
1130 return Poll::Pending;
1131 }
1132 Poll::Ready(Err(_)) => return Poll::Ready(None),
1133 }
1134 }
1135 InboundSubstreamState::PendingFlush(id, mut substream) => {
1136 match substream.poll_flush_unpin(cx) {
1137 Poll::Ready(Ok(())) => {
1138 *this = InboundSubstreamState::WaitingMessage {
1139 first: false,
1140 connection_id: id,
1141 substream,
1142 };
1143 }
1144 Poll::Pending => {
1145 *this = InboundSubstreamState::PendingFlush(id, substream);
1146 return Poll::Pending;
1147 }
1148 Poll::Ready(Err(_)) => return Poll::Ready(None),
1149 }
1150 }
1151 InboundSubstreamState::Closing(mut stream) => match stream.poll_close_unpin(cx) {
1152 Poll::Ready(Ok(())) | Poll::Ready(Err(_)) => return Poll::Ready(None),
1153 Poll::Pending => {
1154 *this = InboundSubstreamState::Closing(stream);
1155 return Poll::Pending;
1156 }
1157 },
1158 InboundSubstreamState::Poisoned { .. } => unreachable!(),
1159 InboundSubstreamState::Cancelled => return Poll::Ready(None),
1160 }
1161 }
1162 }
1163}
1164
1165fn process_kad_response(event: KadResponseMsg, query_id: QueryId) -> HandlerEvent {
1167 match event {
1169 KadResponseMsg::Pong => {
1170 HandlerEvent::QueryError {
1172 error: HandlerQueryErr::UnexpectedMessage,
1173 query_id,
1174 }
1175 }
1176 KadResponseMsg::FindNode { closer_peers } => HandlerEvent::FindNodeRes {
1177 closer_peers,
1178 query_id,
1179 },
1180 KadResponseMsg::GetProviders {
1181 closer_peers,
1182 provider_peers,
1183 } => HandlerEvent::GetProvidersRes {
1184 closer_peers,
1185 provider_peers,
1186 query_id,
1187 },
1188 KadResponseMsg::GetValue {
1189 record,
1190 closer_peers,
1191 } => HandlerEvent::GetRecordRes {
1192 record,
1193 closer_peers,
1194 query_id,
1195 },
1196 KadResponseMsg::PutValue { key, value, .. } => HandlerEvent::PutRecordRes {
1197 key,
1198 value,
1199 query_id,
1200 },
1201 }
1202}
1203
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen};

    impl Arbitrary for ProtocolStatus {
        fn arbitrary(g: &mut Gen) -> Self {
            let supported = bool::arbitrary(g);
            let reported = bool::arbitrary(g);
            ProtocolStatus {
                supported,
                reported,
            }
        }
    }

    #[test]
    fn compute_next_protocol_status_test() {
        let _ = env_logger::try_init();

        fn prop(now_supported: bool, current: Option<ProtocolStatus>) {
            let next = compute_new_protocol_status(
                now_supported,
                current,
                PeerId::random(),
                ConnectionId::new_unchecked(0),
            );

            // The new status always mirrors the latest observation.
            assert_eq!(next.supported, now_supported);

            // It only counts as reported when a previous status exists and agrees with it.
            let expect_reported = current.map_or(false, |c| c.supported == now_supported);
            assert_eq!(next.reported, expect_reported);
        }

        quickcheck::quickcheck(prop as fn(_, _))
    }
}