libp2p_kad/
handler.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::behaviour::Mode;
22use crate::protocol::{
23    KadInStreamSink, KadOutStreamSink, KadPeer, KadRequestMsg, KadResponseMsg, ProtocolConfig,
24};
25use crate::record_priv::{self, Record};
26use crate::QueryId;
27use either::Either;
28use futures::prelude::*;
29use futures::stream::SelectAll;
30use instant::Instant;
31use libp2p_core::{upgrade, ConnectedPoint};
32use libp2p_identity::PeerId;
33use libp2p_swarm::handler::{
34    ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound,
35};
36use libp2p_swarm::{
37    ConnectionHandler, ConnectionHandlerEvent, ConnectionId, KeepAlive, Stream, StreamUpgradeError,
38    SubstreamProtocol, SupportedProtocols,
39};
40use log::trace;
41use std::collections::VecDeque;
42use std::task::Waker;
43use std::{
44    error, fmt, io, marker::PhantomData, pin::Pin, task::Context, task::Poll, time::Duration,
45};
46
/// Per-connection cap on substreams, applied separately to inbound substreams and to
/// in-progress outbound substreams (open plus requested).
const MAX_NUM_SUBSTREAMS: usize = 32;
48
49/// Protocol handler that manages substreams for the Kademlia protocol
50/// on a single connection with a peer.
51///
52/// The handler will automatically open a Kademlia substream with the remote for each request we
53/// make.
54///
55/// It also handles requests made by the remote.
56pub struct Handler {
57    /// Configuration of the wire protocol.
58    protocol_config: ProtocolConfig,
59
60    /// In client mode, we don't accept inbound substreams.
61    mode: Mode,
62
63    /// Time after which we close an idle connection.
64    idle_timeout: Duration,
65
66    /// Next unique ID of a connection.
67    next_connec_unique_id: UniqueConnecId,
68
69    /// List of active outbound substreams with the state they are in.
70    outbound_substreams: SelectAll<OutboundSubstreamState>,
71
72    /// Number of outbound streams being upgraded right now.
73    num_requested_outbound_streams: usize,
74
75    /// List of outbound substreams that are waiting to become active next.
76    /// Contains the request we want to send, and the user data if we expect an answer.
77    pending_messages: VecDeque<(KadRequestMsg, Option<QueryId>)>,
78
79    /// List of active inbound substreams with the state they are in.
80    inbound_substreams: SelectAll<InboundSubstreamState>,
81
82    /// Until when to keep the connection alive.
83    keep_alive: KeepAlive,
84
85    /// The connected endpoint of the connection that the handler
86    /// is associated with.
87    endpoint: ConnectedPoint,
88
89    /// The [`PeerId`] of the remote.
90    remote_peer_id: PeerId,
91
92    /// The current state of protocol confirmation.
93    protocol_status: Option<ProtocolStatus>,
94
95    remote_supported_protocols: SupportedProtocols,
96
97    /// The ID of this connection.
98    connection_id: ConnectionId,
99}
100
/// Protocol-confirmation state of a connection, as tracked by its handler.
#[derive(Debug, Clone, Copy, PartialEq)]
struct ProtocolStatus {
    /// `true` if the remote supports one of our Kademlia protocols.
    supported: bool,
    /// `true` once this state has been reported to the behaviour.
    reported: bool,
}
110
111/// State of an active outbound substream.
112enum OutboundSubstreamState {
113    /// Waiting to send a message to the remote.
114    PendingSend(KadOutStreamSink<Stream>, KadRequestMsg, Option<QueryId>),
115    /// Waiting to flush the substream so that the data arrives to the remote.
116    PendingFlush(KadOutStreamSink<Stream>, Option<QueryId>),
117    /// Waiting for an answer back from the remote.
118    // TODO: add timeout
119    WaitingAnswer(KadOutStreamSink<Stream>, QueryId),
120    /// An error happened on the substream and we should report the error to the user.
121    ReportError(HandlerQueryErr, QueryId),
122    /// The substream is being closed.
123    Closing(KadOutStreamSink<Stream>),
124    /// The substream is complete and will not perform any more work.
125    Done,
126    Poisoned,
127}
128
129/// State of an active inbound substream.
130enum InboundSubstreamState {
131    /// Waiting for a request from the remote.
132    WaitingMessage {
133        /// Whether it is the first message to be awaited on this stream.
134        first: bool,
135        connection_id: UniqueConnecId,
136        substream: KadInStreamSink<Stream>,
137    },
138    /// Waiting for the behaviour to send a [`HandlerIn`] event containing the response.
139    WaitingBehaviour(UniqueConnecId, KadInStreamSink<Stream>, Option<Waker>),
140    /// Waiting to send an answer back to the remote.
141    PendingSend(UniqueConnecId, KadInStreamSink<Stream>, KadResponseMsg),
142    /// Waiting to flush an answer back to the remote.
143    PendingFlush(UniqueConnecId, KadInStreamSink<Stream>),
144    /// The substream is being closed.
145    Closing(KadInStreamSink<Stream>),
146    /// The substream was cancelled in favor of a new one.
147    Cancelled,
148
149    Poisoned {
150        phantom: PhantomData<QueryId>,
151    },
152}
153
154impl InboundSubstreamState {
155    fn try_answer_with(
156        &mut self,
157        id: RequestId,
158        msg: KadResponseMsg,
159    ) -> Result<(), KadResponseMsg> {
160        match std::mem::replace(
161            self,
162            InboundSubstreamState::Poisoned {
163                phantom: PhantomData,
164            },
165        ) {
166            InboundSubstreamState::WaitingBehaviour(conn_id, substream, mut waker)
167                if conn_id == id.connec_unique_id =>
168            {
169                *self = InboundSubstreamState::PendingSend(conn_id, substream, msg);
170
171                if let Some(waker) = waker.take() {
172                    waker.wake();
173                }
174
175                Ok(())
176            }
177            other => {
178                *self = other;
179
180                Err(msg)
181            }
182        }
183    }
184
185    fn close(&mut self) {
186        match std::mem::replace(
187            self,
188            InboundSubstreamState::Poisoned {
189                phantom: PhantomData,
190            },
191        ) {
192            InboundSubstreamState::WaitingMessage { substream, .. }
193            | InboundSubstreamState::WaitingBehaviour(_, substream, _)
194            | InboundSubstreamState::PendingSend(_, substream, _)
195            | InboundSubstreamState::PendingFlush(_, substream)
196            | InboundSubstreamState::Closing(substream) => {
197                *self = InboundSubstreamState::Closing(substream);
198            }
199            InboundSubstreamState::Cancelled => {
200                *self = InboundSubstreamState::Cancelled;
201            }
202            InboundSubstreamState::Poisoned { .. } => unreachable!(),
203        }
204    }
205}
206
207/// Event produced by the Kademlia handler.
208#[derive(Debug)]
209pub enum HandlerEvent {
210    /// The configured protocol name has been confirmed by the peer through
211    /// a successfully negotiated substream or by learning the supported protocols of the remote.
212    ProtocolConfirmed { endpoint: ConnectedPoint },
213    /// The configured protocol name(s) are not or no longer supported by the peer on the provided
214    /// connection and it should be removed from the routing table.
215    ProtocolNotSupported { endpoint: ConnectedPoint },
216
217    /// Request for the list of nodes whose IDs are the closest to `key`. The number of nodes
218    /// returned is not specified, but should be around 20.
219    FindNodeReq {
220        /// The key for which to locate the closest nodes.
221        key: Vec<u8>,
222        /// Identifier of the request. Needs to be passed back when answering.
223        request_id: RequestId,
224    },
225
226    /// Response to an `HandlerIn::FindNodeReq`.
227    FindNodeRes {
228        /// Results of the request.
229        closer_peers: Vec<KadPeer>,
230        /// The user data passed to the `FindNodeReq`.
231        query_id: QueryId,
232    },
233
234    /// Same as `FindNodeReq`, but should also return the entries of the local providers list for
235    /// this key.
236    GetProvidersReq {
237        /// The key for which providers are requested.
238        key: record_priv::Key,
239        /// Identifier of the request. Needs to be passed back when answering.
240        request_id: RequestId,
241    },
242
243    /// Response to an `HandlerIn::GetProvidersReq`.
244    GetProvidersRes {
245        /// Nodes closest to the key.
246        closer_peers: Vec<KadPeer>,
247        /// Known providers for this key.
248        provider_peers: Vec<KadPeer>,
249        /// The user data passed to the `GetProvidersReq`.
250        query_id: QueryId,
251    },
252
253    /// An error happened when performing a query.
254    QueryError {
255        /// The error that happened.
256        error: HandlerQueryErr,
257        /// The user data passed to the query.
258        query_id: QueryId,
259    },
260
261    /// The peer announced itself as a provider of a key.
262    AddProvider {
263        /// The key for which the peer is a provider of the associated value.
264        key: record_priv::Key,
265        /// The peer that is the provider of the value for `key`.
266        provider: KadPeer,
267    },
268
269    /// Request to get a value from the dht records
270    GetRecord {
271        /// Key for which we should look in the dht
272        key: record_priv::Key,
273        /// Identifier of the request. Needs to be passed back when answering.
274        request_id: RequestId,
275    },
276
277    /// Response to a `HandlerIn::GetRecord`.
278    GetRecordRes {
279        /// The result is present if the key has been found
280        record: Option<Record>,
281        /// Nodes closest to the key.
282        closer_peers: Vec<KadPeer>,
283        /// The user data passed to the `GetValue`.
284        query_id: QueryId,
285    },
286
287    /// Request to put a value in the dht records
288    PutRecord {
289        record: Record,
290        /// Identifier of the request. Needs to be passed back when answering.
291        request_id: RequestId,
292    },
293
294    /// Response to a request to store a record.
295    PutRecordRes {
296        /// The key of the stored record.
297        key: record_priv::Key,
298        /// The value of the stored record.
299        value: Vec<u8>,
300        /// The user data passed to the `PutValue`.
301        query_id: QueryId,
302    },
303}
304
305/// Error that can happen when requesting an RPC query.
306#[derive(Debug)]
307pub enum HandlerQueryErr {
308    /// Error while trying to perform the query.
309    Upgrade(StreamUpgradeError<io::Error>),
310    /// Received an answer that doesn't correspond to the request.
311    UnexpectedMessage,
312    /// I/O error in the substream.
313    Io(io::Error),
314}
315
316impl fmt::Display for HandlerQueryErr {
317    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318        match self {
319            HandlerQueryErr::Upgrade(err) => {
320                write!(f, "Error while performing Kademlia query: {err}")
321            }
322            HandlerQueryErr::UnexpectedMessage => {
323                write!(
324                    f,
325                    "Remote answered our Kademlia RPC query with the wrong message type"
326                )
327            }
328            HandlerQueryErr::Io(err) => {
329                write!(f, "I/O error during a Kademlia RPC query: {err}")
330            }
331        }
332    }
333}
334
335impl error::Error for HandlerQueryErr {
336    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
337        match self {
338            HandlerQueryErr::Upgrade(err) => Some(err),
339            HandlerQueryErr::UnexpectedMessage => None,
340            HandlerQueryErr::Io(err) => Some(err),
341        }
342    }
343}
344
345impl From<StreamUpgradeError<io::Error>> for HandlerQueryErr {
346    fn from(err: StreamUpgradeError<io::Error>) -> Self {
347        HandlerQueryErr::Upgrade(err)
348    }
349}
350
351/// Event to send to the handler.
352#[derive(Debug)]
353pub enum HandlerIn {
354    /// Resets the (sub)stream associated with the given request ID,
355    /// thus signaling an error to the remote.
356    ///
357    /// Explicitly resetting the (sub)stream associated with a request
358    /// can be used as an alternative to letting requests simply time
359    /// out on the remote peer, thus potentially avoiding some delay
360    /// for the query on the remote.
361    Reset(RequestId),
362
363    /// Change the connection to the specified mode.
364    ReconfigureMode { new_mode: Mode },
365
366    /// Request for the list of nodes whose IDs are the closest to `key`. The number of nodes
367    /// returned is not specified, but should be around 20.
368    FindNodeReq {
369        /// Identifier of the node.
370        key: Vec<u8>,
371        /// Custom user data. Passed back in the out event when the results arrive.
372        query_id: QueryId,
373    },
374
375    /// Response to a `FindNodeReq`.
376    FindNodeRes {
377        /// Results of the request.
378        closer_peers: Vec<KadPeer>,
379        /// Identifier of the request that was made by the remote.
380        ///
381        /// It is a logic error to use an id of the handler of a different node.
382        request_id: RequestId,
383    },
384
385    /// Same as `FindNodeReq`, but should also return the entries of the local providers list for
386    /// this key.
387    GetProvidersReq {
388        /// Identifier being searched.
389        key: record_priv::Key,
390        /// Custom user data. Passed back in the out event when the results arrive.
391        query_id: QueryId,
392    },
393
394    /// Response to a `GetProvidersReq`.
395    GetProvidersRes {
396        /// Nodes closest to the key.
397        closer_peers: Vec<KadPeer>,
398        /// Known providers for this key.
399        provider_peers: Vec<KadPeer>,
400        /// Identifier of the request that was made by the remote.
401        ///
402        /// It is a logic error to use an id of the handler of a different node.
403        request_id: RequestId,
404    },
405
406    /// Indicates that this provider is known for this key.
407    ///
408    /// The API of the handler doesn't expose any event that allows you to know whether this
409    /// succeeded.
410    AddProvider {
411        /// Key for which we should add providers.
412        key: record_priv::Key,
413        /// Known provider for this key.
414        provider: KadPeer,
415    },
416
417    /// Request to retrieve a record from the DHT.
418    GetRecord {
419        /// The key of the record.
420        key: record_priv::Key,
421        /// Custom data. Passed back in the out event when the results arrive.
422        query_id: QueryId,
423    },
424
425    /// Response to a `GetRecord` request.
426    GetRecordRes {
427        /// The value that might have been found in our storage.
428        record: Option<Record>,
429        /// Nodes that are closer to the key we were searching for.
430        closer_peers: Vec<KadPeer>,
431        /// Identifier of the request that was made by the remote.
432        request_id: RequestId,
433    },
434
435    /// Put a value into the dht records.
436    PutRecord {
437        record: Record,
438        /// Custom data. Passed back in the out event when the results arrive.
439        query_id: QueryId,
440    },
441
442    /// Response to a `PutRecord`.
443    PutRecordRes {
444        /// Key of the value that was put.
445        key: record_priv::Key,
446        /// Value that was put.
447        value: Vec<u8>,
448        /// Identifier of the request that was made by the remote.
449        request_id: RequestId,
450    },
451}
452
453/// Unique identifier for a request. Must be passed back in order to answer a request from
454/// the remote.
455#[derive(Debug, PartialEq, Eq, Copy, Clone)]
456pub struct RequestId {
457    /// Unique identifier for an incoming connection.
458    connec_unique_id: UniqueConnecId,
459}
460
/// Per-handler identifier, handed out in increasing order to inbound substreams.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct UniqueConnecId(u64);
464
465impl Handler {
466    pub fn new(
467        protocol_config: ProtocolConfig,
468        idle_timeout: Duration,
469        endpoint: ConnectedPoint,
470        remote_peer_id: PeerId,
471        mode: Mode,
472        connection_id: ConnectionId,
473    ) -> Self {
474        match &endpoint {
475            ConnectedPoint::Dialer { .. } => {
476                log::debug!(
477                    "Operating in {mode}-mode on new outbound connection to {remote_peer_id}"
478                );
479            }
480            ConnectedPoint::Listener { .. } => {
481                log::debug!(
482                    "Operating in {mode}-mode on new inbound connection to {remote_peer_id}"
483                );
484            }
485        }
486
487        #[allow(deprecated)]
488        let keep_alive = KeepAlive::Until(Instant::now() + idle_timeout);
489
490        Handler {
491            protocol_config,
492            mode,
493            idle_timeout,
494            endpoint,
495            remote_peer_id,
496            next_connec_unique_id: UniqueConnecId(0),
497            inbound_substreams: Default::default(),
498            outbound_substreams: Default::default(),
499            num_requested_outbound_streams: 0,
500            pending_messages: Default::default(),
501            keep_alive,
502            protocol_status: None,
503            remote_supported_protocols: Default::default(),
504            connection_id,
505        }
506    }
507
508    fn on_fully_negotiated_outbound(
509        &mut self,
510        FullyNegotiatedOutbound { protocol, info: () }: FullyNegotiatedOutbound<
511            <Self as ConnectionHandler>::OutboundProtocol,
512            <Self as ConnectionHandler>::OutboundOpenInfo,
513        >,
514    ) {
515        if let Some((msg, query_id)) = self.pending_messages.pop_front() {
516            self.outbound_substreams
517                .push(OutboundSubstreamState::PendingSend(protocol, msg, query_id));
518        } else {
519            debug_assert!(false, "Requested outbound stream without message")
520        }
521
522        self.num_requested_outbound_streams -= 1;
523
524        if self.protocol_status.is_none() {
525            // Upon the first successfully negotiated substream, we know that the
526            // remote is configured with the same protocol name and we want
527            // the behaviour to add this peer to the routing table, if possible.
528            self.protocol_status = Some(ProtocolStatus {
529                supported: true,
530                reported: false,
531            });
532        }
533    }
534
535    fn on_fully_negotiated_inbound(
536        &mut self,
537        FullyNegotiatedInbound { protocol, .. }: FullyNegotiatedInbound<
538            <Self as ConnectionHandler>::InboundProtocol,
539            <Self as ConnectionHandler>::InboundOpenInfo,
540        >,
541    ) {
542        // If `self.allow_listening` is false, then we produced a `DeniedUpgrade` and `protocol`
543        // is a `Void`.
544        let protocol = match protocol {
545            future::Either::Left(p) => p,
546            future::Either::Right(p) => void::unreachable(p),
547        };
548
549        if self.protocol_status.is_none() {
550            // Upon the first successfully negotiated substream, we know that the
551            // remote is configured with the same protocol name and we want
552            // the behaviour to add this peer to the routing table, if possible.
553            self.protocol_status = Some(ProtocolStatus {
554                supported: true,
555                reported: false,
556            });
557        }
558
559        if self.inbound_substreams.len() == MAX_NUM_SUBSTREAMS {
560            if let Some(s) = self.inbound_substreams.iter_mut().find(|s| {
561                matches!(
562                    s,
563                    // An inbound substream waiting to be reused.
564                    InboundSubstreamState::WaitingMessage { first: false, .. }
565                )
566            }) {
567                *s = InboundSubstreamState::Cancelled;
568                log::debug!(
569                    "New inbound substream to {:?} exceeds inbound substream limit. \
570                    Removed older substream waiting to be reused.",
571                    self.remote_peer_id,
572                )
573            } else {
574                log::warn!(
575                    "New inbound substream to {:?} exceeds inbound substream limit. \
576                     No older substream waiting to be reused. Dropping new substream.",
577                    self.remote_peer_id,
578                );
579                return;
580            }
581        }
582
583        let connec_unique_id = self.next_connec_unique_id;
584        self.next_connec_unique_id.0 += 1;
585        self.inbound_substreams
586            .push(InboundSubstreamState::WaitingMessage {
587                first: true,
588                connection_id: connec_unique_id,
589                substream: protocol,
590            });
591    }
592
593    fn on_dial_upgrade_error(
594        &mut self,
595        DialUpgradeError {
596            info: (), error, ..
597        }: DialUpgradeError<
598            <Self as ConnectionHandler>::OutboundOpenInfo,
599            <Self as ConnectionHandler>::OutboundProtocol,
600        >,
601    ) {
602        // TODO: cache the fact that the remote doesn't support kademlia at all, so that we don't
603        //       continue trying
604
605        if let Some((_, Some(query_id))) = self.pending_messages.pop_front() {
606            self.outbound_substreams
607                .push(OutboundSubstreamState::ReportError(error.into(), query_id));
608        }
609
610        self.num_requested_outbound_streams -= 1;
611    }
612}
613
impl ConnectionHandler for Handler {
    type FromBehaviour = HandlerIn;
    type ToBehaviour = HandlerEvent;
    type Error = io::Error; // TODO: better error type?
    // `Left` accepts our protocol (server mode); `Right` denies every inbound substream
    // (client mode). See `listen_protocol`.
    type InboundProtocol = Either<ProtocolConfig, upgrade::DeniedUpgrade>;
    type OutboundProtocol = ProtocolConfig;
    type OutboundOpenInfo = ();
    type InboundOpenInfo = ();

    /// Returns the upgrade offered on inbound substreams: our Kademlia protocol in server
    /// mode, a `DeniedUpgrade` in client mode.
    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
        match self.mode {
            Mode::Server => SubstreamProtocol::new(Either::Left(self.protocol_config.clone()), ()),
            Mode::Client => SubstreamProtocol::new(Either::Right(upgrade::DeniedUpgrade), ()),
        }
    }

    /// Applies a command from the behaviour.
    ///
    /// Outgoing requests are appended to `pending_messages`; `poll` later requests an outbound
    /// substream for each of them. Responses are handed to the inbound substream whose
    /// identifier matches the `RequestId` (see `answer_pending_request`).
    fn on_behaviour_event(&mut self, message: HandlerIn) {
        match message {
            HandlerIn::Reset(request_id) => {
                // Only a substream still waiting for the behaviour's answer to this request is
                // closed; substreams in any other state are left untouched.
                if let Some(state) = self
                    .inbound_substreams
                    .iter_mut()
                    .find(|state| match state {
                        InboundSubstreamState::WaitingBehaviour(conn_id, _, _) => {
                            conn_id == &request_id.connec_unique_id
                        }
                        _ => false,
                    })
                {
                    state.close();
                }
            }
            HandlerIn::FindNodeReq { key, query_id } => {
                let msg = KadRequestMsg::FindNode { key };
                self.pending_messages.push_back((msg, Some(query_id)));
            }
            HandlerIn::FindNodeRes {
                closer_peers,
                request_id,
            } => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }),
            HandlerIn::GetProvidersReq { key, query_id } => {
                let msg = KadRequestMsg::GetProviders { key };
                self.pending_messages.push_back((msg, Some(query_id)));
            }
            HandlerIn::GetProvidersRes {
                closer_peers,
                provider_peers,
                request_id,
            } => self.answer_pending_request(
                request_id,
                KadResponseMsg::GetProviders {
                    closer_peers,
                    provider_peers,
                },
            ),
            HandlerIn::AddProvider { key, provider } => {
                // The handler exposes no outcome for `AddProvider`, so it is queued without a
                // query id.
                let msg = KadRequestMsg::AddProvider { key, provider };
                self.pending_messages.push_back((msg, None));
            }
            HandlerIn::GetRecord { key, query_id } => {
                let msg = KadRequestMsg::GetValue { key };
                self.pending_messages.push_back((msg, Some(query_id)));
            }
            HandlerIn::PutRecord { record, query_id } => {
                let msg = KadRequestMsg::PutValue { record };
                self.pending_messages.push_back((msg, Some(query_id)));
            }
            HandlerIn::GetRecordRes {
                record,
                closer_peers,
                request_id,
            } => {
                self.answer_pending_request(
                    request_id,
                    KadResponseMsg::GetValue {
                        record,
                        closer_peers,
                    },
                );
            }
            HandlerIn::PutRecordRes {
                key,
                request_id,
                value,
            } => {
                self.answer_pending_request(request_id, KadResponseMsg::PutValue { key, value });
            }
            HandlerIn::ReconfigureMode { new_mode } => {
                let peer = self.remote_peer_id;

                match &self.endpoint {
                    ConnectedPoint::Dialer { .. } => {
                        log::debug!(
                            "Now operating in {new_mode}-mode on outbound connection with {peer}"
                        )
                    }
                    ConnectedPoint::Listener { local_addr, .. } => {
                        log::debug!("Now operating in {new_mode}-mode on inbound connection with {peer} assuming that one of our external addresses routes to {local_addr}")
                    }
                }

                // Only `listen_protocol` reads `mode`; substreams that are already open are
                // not affected by the switch.
                self.mode = new_mode;
            }
        }
    }

    /// Returns the keep-alive last computed by `poll`.
    fn connection_keep_alive(&self) -> KeepAlive {
        self.keep_alive
    }

    /// Drives the handler and emits at most one event per call.
    ///
    /// In order, it:
    /// 1. reports a not-yet-reported protocol status;
    /// 2. forwards an event from the outbound substreams, then from the inbound substreams;
    /// 3. requests one more outbound substream for a queued message, within
    ///    `MAX_NUM_SUBSTREAMS`.
    ///
    /// The keep-alive is only refreshed when none of these produced an event.
    fn poll(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<
        ConnectionHandlerEvent<
            Self::OutboundProtocol,
            Self::OutboundOpenInfo,
            Self::ToBehaviour,
            Self::Error,
        >,
    > {
        // Tell the behaviour, exactly once per change, whether the remote speaks our protocol.
        match &mut self.protocol_status {
            Some(status) if !status.reported => {
                status.reported = true;
                let event = if status.supported {
                    HandlerEvent::ProtocolConfirmed {
                        endpoint: self.endpoint.clone(),
                    }
                } else {
                    HandlerEvent::ProtocolNotSupported {
                        endpoint: self.endpoint.clone(),
                    }
                };

                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
            }
            _ => {}
        }

        if let Poll::Ready(Some(event)) = self.outbound_substreams.poll_next_unpin(cx) {
            return Poll::Ready(event);
        }

        if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) {
            return Poll::Ready(event);
        }

        // Open plus requested outbound substreams count towards the limit. A new one is only
        // requested while some queued message has no requested substream yet.
        let num_in_progress_outbound_substreams =
            self.outbound_substreams.len() + self.num_requested_outbound_streams;
        if num_in_progress_outbound_substreams < MAX_NUM_SUBSTREAMS
            && self.num_requested_outbound_streams < self.pending_messages.len()
        {
            self.num_requested_outbound_streams += 1;
            return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
                protocol: SubstreamProtocol::new(self.protocol_config.clone(), ()),
            });
        }

        let no_streams = self.outbound_substreams.is_empty() && self.inbound_substreams.is_empty();

        self.keep_alive = {
            #[allow(deprecated)]
            match (no_streams, self.keep_alive) {
                // No open streams. Preserve the existing idle timeout.
                (true, k @ KeepAlive::Until(_)) => k,
                // No open streams. Set idle timeout.
                (true, _) => KeepAlive::Until(Instant::now() + self.idle_timeout),
                // Keep alive for open streams.
                (false, _) => KeepAlive::Yes,
            }
        };

        Poll::Pending
    }

    /// Dispatches connection-level events to the matching handler method, and tracks changes
    /// in the remote's supported protocols.
    fn on_connection_event(
        &mut self,
        event: ConnectionEvent<
            Self::InboundProtocol,
            Self::OutboundProtocol,
            Self::InboundOpenInfo,
            Self::OutboundOpenInfo,
        >,
    ) {
        match event {
            ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
                self.on_fully_negotiated_outbound(fully_negotiated_outbound)
            }
            ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
                self.on_fully_negotiated_inbound(fully_negotiated_inbound)
            }
            ConnectionEvent::DialUpgradeError(dial_upgrade_error) => {
                self.on_dial_upgrade_error(dial_upgrade_error)
            }
            ConnectionEvent::AddressChange(_)
            | ConnectionEvent::ListenUpgradeError(_)
            | ConnectionEvent::LocalProtocolsChange(_) => {}
            ConnectionEvent::RemoteProtocolsChange(change) => {
                // Presumably `dirty` is true only when the stored protocol set actually
                // changed; the status is recomputed only in that case.
                let dirty = self.remote_supported_protocols.on_protocols_change(change);

                if dirty {
                    let remote_supports_our_kademlia_protocols = self
                        .remote_supported_protocols
                        .iter()
                        .any(|p| self.protocol_config.protocol_names().contains(p));

                    self.protocol_status = Some(compute_new_protocol_status(
                        remote_supports_our_kademlia_protocols,
                        self.protocol_status,
                        self.remote_peer_id,
                        self.connection_id,
                    ))
                }
            }
        }
    }
}
831
832fn compute_new_protocol_status(
833    now_supported: bool,
834    current_status: Option<ProtocolStatus>,
835    remote_peer_id: PeerId,
836    connection_id: ConnectionId,
837) -> ProtocolStatus {
838    let current_status = match current_status {
839        None => {
840            return ProtocolStatus {
841                supported: now_supported,
842                reported: false,
843            }
844        }
845        Some(current) => current,
846    };
847
848    if now_supported == current_status.supported {
849        return ProtocolStatus {
850            supported: now_supported,
851            reported: true,
852        };
853    }
854
855    if now_supported {
856        log::debug!("Remote {remote_peer_id} now supports our kademlia protocol on connection {connection_id}");
857    } else {
858        log::debug!("Remote {remote_peer_id} no longer supports our kademlia protocol on connection {connection_id}");
859    }
860
861    ProtocolStatus {
862        supported: now_supported,
863        reported: false,
864    }
865}
866
867impl Handler {
868    fn answer_pending_request(&mut self, request_id: RequestId, mut msg: KadResponseMsg) {
869        for state in self.inbound_substreams.iter_mut() {
870            match state.try_answer_with(request_id, msg) {
871                Ok(()) => return,
872                Err(m) => {
873                    msg = m;
874                }
875            }
876        }
877
878        debug_assert!(false, "Cannot find inbound substream for {request_id:?}")
879    }
880}
881
882impl futures::Stream for OutboundSubstreamState {
883    type Item = ConnectionHandlerEvent<ProtocolConfig, (), HandlerEvent, io::Error>;
884
885    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
886        let this = self.get_mut();
887
888        loop {
889            match std::mem::replace(this, OutboundSubstreamState::Poisoned) {
890                OutboundSubstreamState::PendingSend(mut substream, msg, query_id) => {
891                    match substream.poll_ready_unpin(cx) {
892                        Poll::Ready(Ok(())) => match substream.start_send_unpin(msg) {
893                            Ok(()) => {
894                                *this = OutboundSubstreamState::PendingFlush(substream, query_id);
895                            }
896                            Err(error) => {
897                                *this = OutboundSubstreamState::Done;
898                                let event = query_id.map(|query_id| {
899                                    ConnectionHandlerEvent::NotifyBehaviour(
900                                        HandlerEvent::QueryError {
901                                            error: HandlerQueryErr::Io(error),
902                                            query_id,
903                                        },
904                                    )
905                                });
906
907                                return Poll::Ready(event);
908                            }
909                        },
910                        Poll::Pending => {
911                            *this = OutboundSubstreamState::PendingSend(substream, msg, query_id);
912                            return Poll::Pending;
913                        }
914                        Poll::Ready(Err(error)) => {
915                            *this = OutboundSubstreamState::Done;
916                            let event = query_id.map(|query_id| {
917                                ConnectionHandlerEvent::NotifyBehaviour(HandlerEvent::QueryError {
918                                    error: HandlerQueryErr::Io(error),
919                                    query_id,
920                                })
921                            });
922
923                            return Poll::Ready(event);
924                        }
925                    }
926                }
927                OutboundSubstreamState::PendingFlush(mut substream, query_id) => {
928                    match substream.poll_flush_unpin(cx) {
929                        Poll::Ready(Ok(())) => {
930                            if let Some(query_id) = query_id {
931                                *this = OutboundSubstreamState::WaitingAnswer(substream, query_id);
932                            } else {
933                                *this = OutboundSubstreamState::Closing(substream);
934                            }
935                        }
936                        Poll::Pending => {
937                            *this = OutboundSubstreamState::PendingFlush(substream, query_id);
938                            return Poll::Pending;
939                        }
940                        Poll::Ready(Err(error)) => {
941                            *this = OutboundSubstreamState::Done;
942                            let event = query_id.map(|query_id| {
943                                ConnectionHandlerEvent::NotifyBehaviour(HandlerEvent::QueryError {
944                                    error: HandlerQueryErr::Io(error),
945                                    query_id,
946                                })
947                            });
948
949                            return Poll::Ready(event);
950                        }
951                    }
952                }
953                OutboundSubstreamState::WaitingAnswer(mut substream, query_id) => {
954                    match substream.poll_next_unpin(cx) {
955                        Poll::Ready(Some(Ok(msg))) => {
956                            *this = OutboundSubstreamState::Closing(substream);
957                            let event = process_kad_response(msg, query_id);
958
959                            return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
960                                event,
961                            )));
962                        }
963                        Poll::Pending => {
964                            *this = OutboundSubstreamState::WaitingAnswer(substream, query_id);
965                            return Poll::Pending;
966                        }
967                        Poll::Ready(Some(Err(error))) => {
968                            *this = OutboundSubstreamState::Done;
969                            let event = HandlerEvent::QueryError {
970                                error: HandlerQueryErr::Io(error),
971                                query_id,
972                            };
973
974                            return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
975                                event,
976                            )));
977                        }
978                        Poll::Ready(None) => {
979                            *this = OutboundSubstreamState::Done;
980                            let event = HandlerEvent::QueryError {
981                                error: HandlerQueryErr::Io(io::ErrorKind::UnexpectedEof.into()),
982                                query_id,
983                            };
984
985                            return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
986                                event,
987                            )));
988                        }
989                    }
990                }
991                OutboundSubstreamState::ReportError(error, query_id) => {
992                    *this = OutboundSubstreamState::Done;
993                    let event = HandlerEvent::QueryError { error, query_id };
994
995                    return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(event)));
996                }
997                OutboundSubstreamState::Closing(mut stream) => match stream.poll_close_unpin(cx) {
998                    Poll::Ready(Ok(())) | Poll::Ready(Err(_)) => return Poll::Ready(None),
999                    Poll::Pending => {
1000                        *this = OutboundSubstreamState::Closing(stream);
1001                        return Poll::Pending;
1002                    }
1003                },
1004                OutboundSubstreamState::Done => {
1005                    *this = OutboundSubstreamState::Done;
1006                    return Poll::Ready(None);
1007                }
1008                OutboundSubstreamState::Poisoned => unreachable!(),
1009            }
1010        }
1011    }
1012}
1013
1014impl futures::Stream for InboundSubstreamState {
1015    type Item = ConnectionHandlerEvent<ProtocolConfig, (), HandlerEvent, io::Error>;
1016
1017    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1018        let this = self.get_mut();
1019
1020        loop {
1021            match std::mem::replace(
1022                this,
1023                Self::Poisoned {
1024                    phantom: PhantomData,
1025                },
1026            ) {
1027                InboundSubstreamState::WaitingMessage {
1028                    first,
1029                    connection_id,
1030                    mut substream,
1031                } => match substream.poll_next_unpin(cx) {
1032                    Poll::Ready(Some(Ok(KadRequestMsg::Ping))) => {
1033                        log::warn!("Kademlia PING messages are unsupported");
1034
1035                        *this = InboundSubstreamState::Closing(substream);
1036                    }
1037                    Poll::Ready(Some(Ok(KadRequestMsg::FindNode { key }))) => {
1038                        *this =
1039                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
1040                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
1041                            HandlerEvent::FindNodeReq {
1042                                key,
1043                                request_id: RequestId {
1044                                    connec_unique_id: connection_id,
1045                                },
1046                            },
1047                        )));
1048                    }
1049                    Poll::Ready(Some(Ok(KadRequestMsg::GetProviders { key }))) => {
1050                        *this =
1051                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
1052                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
1053                            HandlerEvent::GetProvidersReq {
1054                                key,
1055                                request_id: RequestId {
1056                                    connec_unique_id: connection_id,
1057                                },
1058                            },
1059                        )));
1060                    }
1061                    Poll::Ready(Some(Ok(KadRequestMsg::AddProvider { key, provider }))) => {
1062                        *this = InboundSubstreamState::WaitingMessage {
1063                            first: false,
1064                            connection_id,
1065                            substream,
1066                        };
1067                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
1068                            HandlerEvent::AddProvider { key, provider },
1069                        )));
1070                    }
1071                    Poll::Ready(Some(Ok(KadRequestMsg::GetValue { key }))) => {
1072                        *this =
1073                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
1074                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
1075                            HandlerEvent::GetRecord {
1076                                key,
1077                                request_id: RequestId {
1078                                    connec_unique_id: connection_id,
1079                                },
1080                            },
1081                        )));
1082                    }
1083                    Poll::Ready(Some(Ok(KadRequestMsg::PutValue { record }))) => {
1084                        *this =
1085                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
1086                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
1087                            HandlerEvent::PutRecord {
1088                                record,
1089                                request_id: RequestId {
1090                                    connec_unique_id: connection_id,
1091                                },
1092                            },
1093                        )));
1094                    }
1095                    Poll::Pending => {
1096                        *this = InboundSubstreamState::WaitingMessage {
1097                            first,
1098                            connection_id,
1099                            substream,
1100                        };
1101                        return Poll::Pending;
1102                    }
1103                    Poll::Ready(None) => {
1104                        return Poll::Ready(None);
1105                    }
1106                    Poll::Ready(Some(Err(e))) => {
1107                        trace!("Inbound substream error: {:?}", e);
1108                        return Poll::Ready(None);
1109                    }
1110                },
1111                InboundSubstreamState::WaitingBehaviour(id, substream, _) => {
1112                    *this = InboundSubstreamState::WaitingBehaviour(
1113                        id,
1114                        substream,
1115                        Some(cx.waker().clone()),
1116                    );
1117
1118                    return Poll::Pending;
1119                }
1120                InboundSubstreamState::PendingSend(id, mut substream, msg) => {
1121                    match substream.poll_ready_unpin(cx) {
1122                        Poll::Ready(Ok(())) => match substream.start_send_unpin(msg) {
1123                            Ok(()) => {
1124                                *this = InboundSubstreamState::PendingFlush(id, substream);
1125                            }
1126                            Err(_) => return Poll::Ready(None),
1127                        },
1128                        Poll::Pending => {
1129                            *this = InboundSubstreamState::PendingSend(id, substream, msg);
1130                            return Poll::Pending;
1131                        }
1132                        Poll::Ready(Err(_)) => return Poll::Ready(None),
1133                    }
1134                }
1135                InboundSubstreamState::PendingFlush(id, mut substream) => {
1136                    match substream.poll_flush_unpin(cx) {
1137                        Poll::Ready(Ok(())) => {
1138                            *this = InboundSubstreamState::WaitingMessage {
1139                                first: false,
1140                                connection_id: id,
1141                                substream,
1142                            };
1143                        }
1144                        Poll::Pending => {
1145                            *this = InboundSubstreamState::PendingFlush(id, substream);
1146                            return Poll::Pending;
1147                        }
1148                        Poll::Ready(Err(_)) => return Poll::Ready(None),
1149                    }
1150                }
1151                InboundSubstreamState::Closing(mut stream) => match stream.poll_close_unpin(cx) {
1152                    Poll::Ready(Ok(())) | Poll::Ready(Err(_)) => return Poll::Ready(None),
1153                    Poll::Pending => {
1154                        *this = InboundSubstreamState::Closing(stream);
1155                        return Poll::Pending;
1156                    }
1157                },
1158                InboundSubstreamState::Poisoned { .. } => unreachable!(),
1159                InboundSubstreamState::Cancelled => return Poll::Ready(None),
1160            }
1161        }
1162    }
1163}
1164
1165/// Process a Kademlia message that's supposed to be a response to one of our requests.
1166fn process_kad_response(event: KadResponseMsg, query_id: QueryId) -> HandlerEvent {
1167    // TODO: must check that the response corresponds to the request
1168    match event {
1169        KadResponseMsg::Pong => {
1170            // We never send out pings.
1171            HandlerEvent::QueryError {
1172                error: HandlerQueryErr::UnexpectedMessage,
1173                query_id,
1174            }
1175        }
1176        KadResponseMsg::FindNode { closer_peers } => HandlerEvent::FindNodeRes {
1177            closer_peers,
1178            query_id,
1179        },
1180        KadResponseMsg::GetProviders {
1181            closer_peers,
1182            provider_peers,
1183        } => HandlerEvent::GetProvidersRes {
1184            closer_peers,
1185            provider_peers,
1186            query_id,
1187        },
1188        KadResponseMsg::GetValue {
1189            record,
1190            closer_peers,
1191        } => HandlerEvent::GetRecordRes {
1192            record,
1193            closer_peers,
1194            query_id,
1195        },
1196        KadResponseMsg::PutValue { key, value, .. } => HandlerEvent::PutRecordRes {
1197            key,
1198            value,
1199            query_id,
1200        },
1201    }
1202}
1203
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen};

    /// Uniformly random `ProtocolStatus` values for property testing.
    impl Arbitrary for ProtocolStatus {
        fn arbitrary(g: &mut Gen) -> Self {
            let supported = bool::arbitrary(g);
            let reported = bool::arbitrary(g);

            Self {
                supported,
                reported,
            }
        }
    }

    #[test]
    fn compute_next_protocol_status_test() {
        let _ = env_logger::try_init();

        /// The new status always mirrors `now_supported`. It counts as reported exactly when
        /// a previous status existed with the same support value.
        fn prop(now_supported: bool, current: Option<ProtocolStatus>) {
            let new = compute_new_protocol_status(
                now_supported,
                current,
                PeerId::random(),
                ConnectionId::new_unchecked(0),
            );

            let expect_reported =
                matches!(current, Some(previous) if previous.supported == now_supported);

            assert_eq!(new.supported, now_supported);
            assert_eq!(new.reported, expect_reported);
        }

        quickcheck::quickcheck(prop as fn(_, _))
    }
}