1pub mod either;
42mod map_in;
43mod map_out;
44pub mod multi;
45mod one_shot;
46mod pending;
47mod select;
48
49use crate::connection::AsStrHashEq;
50pub use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, SendWrapper, UpgradeInfoSend};
51pub use map_in::MapInEvent;
52pub use map_out::MapOutEvent;
53pub use one_shot::{OneShotHandler, OneShotHandlerConfig};
54pub use pending::PendingConnectionHandler;
55pub use select::ConnectionHandlerSelect;
56use smallvec::SmallVec;
57
58use crate::StreamProtocol;
59use core::slice;
60use libp2p_core::Multiaddr;
61use std::collections::{HashMap, HashSet};
62use std::{error, fmt, io, task::Context, task::Poll, time::Duration};
63
64pub trait ConnectionHandler: Send + 'static {
98 type FromBehaviour: fmt::Debug + Send + 'static;
100 type ToBehaviour: fmt::Debug + Send + 'static;
102 type InboundProtocol: InboundUpgradeSend;
104 type OutboundProtocol: OutboundUpgradeSend;
106 type InboundOpenInfo: Send + 'static;
108 type OutboundOpenInfo: Send + 'static;
110
111 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
119
120 fn connection_keep_alive(&self) -> bool {
140 false
141 }
142
143 fn poll(
145 &mut self,
146 cx: &mut Context<'_>,
147 ) -> Poll<
148 ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
149 >;
150
151 fn poll_close(&mut self, _: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
165 Poll::Ready(None)
166 }
167
168 fn map_in_event<TNewIn, TMap>(self, map: TMap) -> MapInEvent<Self, TNewIn, TMap>
170 where
171 Self: Sized,
172 TMap: Fn(&TNewIn) -> Option<&Self::FromBehaviour>,
173 {
174 MapInEvent::new(self, map)
175 }
176
177 fn map_out_event<TMap, TNewOut>(self, map: TMap) -> MapOutEvent<Self, TMap>
179 where
180 Self: Sized,
181 TMap: FnMut(Self::ToBehaviour) -> TNewOut,
182 {
183 MapOutEvent::new(self, map)
184 }
185
186 fn select<TProto2>(self, other: TProto2) -> ConnectionHandlerSelect<Self, TProto2>
189 where
190 Self: Sized,
191 {
192 ConnectionHandlerSelect::new(self, other)
193 }
194
195 fn on_behaviour_event(&mut self, _event: Self::FromBehaviour);
197
198 fn on_connection_event(
199 &mut self,
200 event: ConnectionEvent<
201 Self::InboundProtocol,
202 Self::OutboundProtocol,
203 Self::InboundOpenInfo,
204 Self::OutboundOpenInfo,
205 >,
206 );
207}
208
209#[non_exhaustive]
212pub enum ConnectionEvent<'a, IP: InboundUpgradeSend, OP: OutboundUpgradeSend, IOI, OOI> {
213 FullyNegotiatedInbound(FullyNegotiatedInbound<IP, IOI>),
215 FullyNegotiatedOutbound(FullyNegotiatedOutbound<OP, OOI>),
217 AddressChange(AddressChange<'a>),
219 DialUpgradeError(DialUpgradeError<OOI, OP>),
221 ListenUpgradeError(ListenUpgradeError<IOI, IP>),
223 LocalProtocolsChange(ProtocolsChange<'a>),
225 RemoteProtocolsChange(ProtocolsChange<'a>),
227}
228
229impl<'a, IP, OP, IOI, OOI> fmt::Debug for ConnectionEvent<'a, IP, OP, IOI, OOI>
230where
231 IP: InboundUpgradeSend + fmt::Debug,
232 IP::Output: fmt::Debug,
233 IP::Error: fmt::Debug,
234 OP: OutboundUpgradeSend + fmt::Debug,
235 OP::Output: fmt::Debug,
236 OP::Error: fmt::Debug,
237 IOI: fmt::Debug,
238 OOI: fmt::Debug,
239{
240 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
241 match self {
242 ConnectionEvent::FullyNegotiatedInbound(v) => {
243 f.debug_tuple("FullyNegotiatedInbound").field(v).finish()
244 }
245 ConnectionEvent::FullyNegotiatedOutbound(v) => {
246 f.debug_tuple("FullyNegotiatedOutbound").field(v).finish()
247 }
248 ConnectionEvent::AddressChange(v) => f.debug_tuple("AddressChange").field(v).finish(),
249 ConnectionEvent::DialUpgradeError(v) => {
250 f.debug_tuple("DialUpgradeError").field(v).finish()
251 }
252 ConnectionEvent::ListenUpgradeError(v) => {
253 f.debug_tuple("ListenUpgradeError").field(v).finish()
254 }
255 ConnectionEvent::LocalProtocolsChange(v) => {
256 f.debug_tuple("LocalProtocolsChange").field(v).finish()
257 }
258 ConnectionEvent::RemoteProtocolsChange(v) => {
259 f.debug_tuple("RemoteProtocolsChange").field(v).finish()
260 }
261 }
262 }
263}
264
265impl<'a, IP: InboundUpgradeSend, OP: OutboundUpgradeSend, IOI, OOI>
266 ConnectionEvent<'a, IP, OP, IOI, OOI>
267{
268 pub fn is_outbound(&self) -> bool {
270 match self {
271 ConnectionEvent::DialUpgradeError(_) | ConnectionEvent::FullyNegotiatedOutbound(_) => {
272 true
273 }
274 ConnectionEvent::FullyNegotiatedInbound(_)
275 | ConnectionEvent::AddressChange(_)
276 | ConnectionEvent::LocalProtocolsChange(_)
277 | ConnectionEvent::RemoteProtocolsChange(_)
278 | ConnectionEvent::ListenUpgradeError(_) => false,
279 }
280 }
281
282 pub fn is_inbound(&self) -> bool {
284 match self {
285 ConnectionEvent::FullyNegotiatedInbound(_) | ConnectionEvent::ListenUpgradeError(_) => {
286 true
287 }
288 ConnectionEvent::FullyNegotiatedOutbound(_)
289 | ConnectionEvent::AddressChange(_)
290 | ConnectionEvent::LocalProtocolsChange(_)
291 | ConnectionEvent::RemoteProtocolsChange(_)
292 | ConnectionEvent::DialUpgradeError(_) => false,
293 }
294 }
295}
296
297#[derive(Debug)]
306pub struct FullyNegotiatedInbound<IP: InboundUpgradeSend, IOI> {
307 pub protocol: IP::Output,
308 pub info: IOI,
309}
310
311#[derive(Debug)]
316pub struct FullyNegotiatedOutbound<OP: OutboundUpgradeSend, OOI> {
317 pub protocol: OP::Output,
318 pub info: OOI,
319}
320
321#[derive(Debug)]
323pub struct AddressChange<'a> {
324 pub new_address: &'a Multiaddr,
325}
326
327#[derive(Debug, Clone)]
329pub enum ProtocolsChange<'a> {
330 Added(ProtocolsAdded<'a>),
331 Removed(ProtocolsRemoved<'a>),
332}
333
334impl<'a> ProtocolsChange<'a> {
335 pub(crate) fn from_initial_protocols<'b, T: AsRef<str> + 'b>(
337 new_protocols: impl IntoIterator<Item = &'b T>,
338 buffer: &'a mut Vec<StreamProtocol>,
339 ) -> Self {
340 buffer.clear();
341 buffer.extend(
342 new_protocols
343 .into_iter()
344 .filter_map(|i| StreamProtocol::try_from_owned(i.as_ref().to_owned()).ok()),
345 );
346
347 ProtocolsChange::Added(ProtocolsAdded {
348 protocols: buffer.iter(),
349 })
350 }
351
352 pub(crate) fn add(
356 existing_protocols: &HashSet<StreamProtocol>,
357 to_add: HashSet<StreamProtocol>,
358 buffer: &'a mut Vec<StreamProtocol>,
359 ) -> Option<Self> {
360 buffer.clear();
361 buffer.extend(
362 to_add
363 .into_iter()
364 .filter(|i| !existing_protocols.contains(i)),
365 );
366
367 if buffer.is_empty() {
368 return None;
369 }
370
371 Some(Self::Added(ProtocolsAdded {
372 protocols: buffer.iter(),
373 }))
374 }
375
376 pub(crate) fn remove(
380 existing_protocols: &mut HashSet<StreamProtocol>,
381 to_remove: HashSet<StreamProtocol>,
382 buffer: &'a mut Vec<StreamProtocol>,
383 ) -> Option<Self> {
384 buffer.clear();
385 buffer.extend(
386 to_remove
387 .into_iter()
388 .filter_map(|i| existing_protocols.take(&i)),
389 );
390
391 if buffer.is_empty() {
392 return None;
393 }
394
395 Some(Self::Removed(ProtocolsRemoved {
396 protocols: buffer.iter(),
397 }))
398 }
399
400 pub(crate) fn from_full_sets<T: AsRef<str>>(
402 existing_protocols: &mut HashMap<AsStrHashEq<T>, bool>,
403 new_protocols: impl IntoIterator<Item = T>,
404 buffer: &'a mut Vec<StreamProtocol>,
405 ) -> SmallVec<[Self; 2]> {
406 buffer.clear();
407
408 for v in existing_protocols.values_mut() {
410 *v = false;
411 }
412
413 let mut new_protocol_count = 0; for new_protocol in new_protocols {
415 existing_protocols
416 .entry(AsStrHashEq(new_protocol))
417 .and_modify(|v| *v = true) .or_insert_with_key(|k| {
419 buffer.extend(StreamProtocol::try_from_owned(k.0.as_ref().to_owned()).ok());
421 true
422 });
423 new_protocol_count += 1;
424 }
425
426 if new_protocol_count == existing_protocols.len() && buffer.is_empty() {
427 return SmallVec::new();
428 }
429
430 let num_new_protocols = buffer.len();
431 existing_protocols.retain(|p, &mut is_supported| {
434 if !is_supported {
435 buffer.extend(StreamProtocol::try_from_owned(p.0.as_ref().to_owned()).ok());
436 }
437
438 is_supported
439 });
440
441 let (added, removed) = buffer.split_at(num_new_protocols);
442 let mut changes = SmallVec::new();
443 if !added.is_empty() {
444 changes.push(ProtocolsChange::Added(ProtocolsAdded {
445 protocols: added.iter(),
446 }));
447 }
448 if !removed.is_empty() {
449 changes.push(ProtocolsChange::Removed(ProtocolsRemoved {
450 protocols: removed.iter(),
451 }));
452 }
453 changes
454 }
455}
456
457#[derive(Debug, Clone)]
459pub struct ProtocolsAdded<'a> {
460 pub(crate) protocols: slice::Iter<'a, StreamProtocol>,
461}
462
463#[derive(Debug, Clone)]
465pub struct ProtocolsRemoved<'a> {
466 pub(crate) protocols: slice::Iter<'a, StreamProtocol>,
467}
468
469impl<'a> Iterator for ProtocolsAdded<'a> {
470 type Item = &'a StreamProtocol;
471 fn next(&mut self) -> Option<Self::Item> {
472 self.protocols.next()
473 }
474}
475
476impl<'a> Iterator for ProtocolsRemoved<'a> {
477 type Item = &'a StreamProtocol;
478 fn next(&mut self) -> Option<Self::Item> {
479 self.protocols.next()
480 }
481}
482
483#[derive(Debug)]
486pub struct DialUpgradeError<OOI, OP: OutboundUpgradeSend> {
487 pub info: OOI,
488 pub error: StreamUpgradeError<OP::Error>,
489}
490
491#[derive(Debug)]
494pub struct ListenUpgradeError<IOI, IP: InboundUpgradeSend> {
495 pub info: IOI,
496 pub error: IP::Error,
497}
498
/// An upgrade to apply to a substream, bundled with user data (`info`) and a timeout.
///
/// NOTE(review): the timeout is only stored here; presumably it bounds the upgrade and is
/// enforced by the caller — confirm.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SubstreamProtocol<TUpgrade, TInfo> {
    upgrade: TUpgrade,
    info: TInfo,
    timeout: Duration,
}

impl<TUpgrade, TInfo> SubstreamProtocol<TUpgrade, TInfo> {
    /// Bundles `upgrade` and `info` with a default timeout of 10 seconds.
    pub fn new(upgrade: TUpgrade, info: TInfo) -> Self {
        let timeout = Duration::from_secs(10);
        Self {
            upgrade,
            info,
            timeout,
        }
    }

    /// Replaces the upgrade with `f(upgrade)`, keeping `info` and the timeout.
    pub fn map_upgrade<U, F>(self, f: F) -> SubstreamProtocol<U, TInfo>
    where
        F: FnOnce(TUpgrade) -> U,
    {
        let Self {
            upgrade,
            info,
            timeout,
        } = self;
        SubstreamProtocol {
            upgrade: f(upgrade),
            info,
            timeout,
        }
    }

    /// Replaces the info with `f(info)`, keeping the upgrade and the timeout.
    pub fn map_info<U, F>(self, f: F) -> SubstreamProtocol<TUpgrade, U>
    where
        F: FnOnce(TInfo) -> U,
    {
        let Self {
            upgrade,
            info,
            timeout,
        } = self;
        SubstreamProtocol {
            upgrade,
            info: f(info),
            timeout,
        }
    }

    /// Returns the same protocol with its timeout replaced by `timeout`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Borrows the upgrade.
    pub fn upgrade(&self) -> &TUpgrade {
        &self.upgrade
    }

    /// Borrows the user data.
    pub fn info(&self) -> &TInfo {
        &self.info
    }

    /// Borrows the timeout.
    pub fn timeout(&self) -> &Duration {
        &self.timeout
    }

    /// Consumes `self`, returning the upgrade and the info (the timeout is discarded).
    pub fn into_upgrade(self) -> (TUpgrade, TInfo) {
        let Self { upgrade, info, .. } = self;
        (upgrade, info)
    }
}
574
575#[derive(Debug, Clone, PartialEq, Eq)]
577#[non_exhaustive]
578pub enum ConnectionHandlerEvent<TConnectionUpgrade, TOutboundOpenInfo, TCustom> {
579 OutboundSubstreamRequest {
581 protocol: SubstreamProtocol<TConnectionUpgrade, TOutboundOpenInfo>,
583 },
584 ReportRemoteProtocols(ProtocolSupport),
586
587 NotifyBehaviour(TCustom),
589}
590
591#[derive(Debug, Clone, PartialEq, Eq)]
592pub enum ProtocolSupport {
593 Added(HashSet<StreamProtocol>),
595 Removed(HashSet<StreamProtocol>),
597}
598
599impl<TConnectionUpgrade, TOutboundOpenInfo, TCustom>
601 ConnectionHandlerEvent<TConnectionUpgrade, TOutboundOpenInfo, TCustom>
602{
603 pub fn map_outbound_open_info<F, I>(
606 self,
607 map: F,
608 ) -> ConnectionHandlerEvent<TConnectionUpgrade, I, TCustom>
609 where
610 F: FnOnce(TOutboundOpenInfo) -> I,
611 {
612 match self {
613 ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } => {
614 ConnectionHandlerEvent::OutboundSubstreamRequest {
615 protocol: protocol.map_info(map),
616 }
617 }
618 ConnectionHandlerEvent::NotifyBehaviour(val) => {
619 ConnectionHandlerEvent::NotifyBehaviour(val)
620 }
621 ConnectionHandlerEvent::ReportRemoteProtocols(support) => {
622 ConnectionHandlerEvent::ReportRemoteProtocols(support)
623 }
624 }
625 }
626
627 pub fn map_protocol<F, I>(self, map: F) -> ConnectionHandlerEvent<I, TOutboundOpenInfo, TCustom>
630 where
631 F: FnOnce(TConnectionUpgrade) -> I,
632 {
633 match self {
634 ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } => {
635 ConnectionHandlerEvent::OutboundSubstreamRequest {
636 protocol: protocol.map_upgrade(map),
637 }
638 }
639 ConnectionHandlerEvent::NotifyBehaviour(val) => {
640 ConnectionHandlerEvent::NotifyBehaviour(val)
641 }
642 ConnectionHandlerEvent::ReportRemoteProtocols(support) => {
643 ConnectionHandlerEvent::ReportRemoteProtocols(support)
644 }
645 }
646 }
647
648 pub fn map_custom<F, I>(
650 self,
651 map: F,
652 ) -> ConnectionHandlerEvent<TConnectionUpgrade, TOutboundOpenInfo, I>
653 where
654 F: FnOnce(TCustom) -> I,
655 {
656 match self {
657 ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } => {
658 ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }
659 }
660 ConnectionHandlerEvent::NotifyBehaviour(val) => {
661 ConnectionHandlerEvent::NotifyBehaviour(map(val))
662 }
663 ConnectionHandlerEvent::ReportRemoteProtocols(support) => {
664 ConnectionHandlerEvent::ReportRemoteProtocols(support)
665 }
666 }
667 }
668}
669
/// Why upgrading a substream failed.
#[derive(Debug)]
pub enum StreamUpgradeError<TUpgrErr> {
    /// The substream could not be opened in time.
    Timeout,
    /// The upgrade itself returned an error.
    Apply(TUpgrErr),
    /// No protocol could be agreed upon.
    NegotiationFailed,
    /// An I/O error occurred.
    Io(io::Error),
}

impl<TUpgrErr> StreamUpgradeError<TUpgrErr> {
    /// Converts the upgrade-specific error type with `f`.
    ///
    /// Only `Apply` carries such an error, so `f` runs only for that variant; all other
    /// variants are carried over as they are.
    pub fn map_upgrade_err<F, E>(self, f: F) -> StreamUpgradeError<E>
    where
        F: FnOnce(TUpgrErr) -> E,
    {
        match self {
            Self::Apply(inner) => StreamUpgradeError::Apply(f(inner)),
            Self::Io(io_err) => StreamUpgradeError::Io(io_err),
            Self::Timeout => StreamUpgradeError::Timeout,
            Self::NegotiationFailed => StreamUpgradeError::NegotiationFailed,
        }
    }
}
697
698impl<TUpgrErr> fmt::Display for StreamUpgradeError<TUpgrErr>
699where
700 TUpgrErr: error::Error + 'static,
701{
702 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
703 match self {
704 StreamUpgradeError::Timeout => {
705 write!(f, "Timeout error while opening a substream")
706 }
707 StreamUpgradeError::Apply(err) => {
708 write!(f, "Apply: ")?;
709 crate::print_error_chain(f, err)
710 }
711 StreamUpgradeError::NegotiationFailed => {
712 write!(f, "no protocols could be agreed upon")
713 }
714 StreamUpgradeError::Io(e) => {
715 write!(f, "IO error: ")?;
716 crate::print_error_chain(f, e)
717 }
718 }
719 }
720}
721
722impl<TUpgrErr> error::Error for StreamUpgradeError<TUpgrErr>
723where
724 TUpgrErr: error::Error + 'static,
725{
726 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
727 None
728 }
729}
730
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a set of protocols from whitespace-separated names, prefixing each with `/`.
    fn protocol_set_of(s: &'static str) -> HashSet<StreamProtocol> {
        s.split_whitespace()
            .map(|p| StreamProtocol::try_from_owned(format!("/{p}")).unwrap())
            .collect()
    }

    /// Runs `ProtocolsChange::remove` and returns the protocols reported as removed.
    ///
    /// Panics if an `Added` change is reported.
    fn test_remove(
        existing: &mut HashSet<StreamProtocol>,
        to_remove: HashSet<StreamProtocol>,
    ) -> HashSet<StreamProtocol> {
        ProtocolsChange::remove(existing, to_remove, &mut Vec::new())
            .into_iter()
            .flat_map(|c| match c {
                ProtocolsChange::Added(_) => panic!("unexpected added"),
                ProtocolsChange::Removed(r) => r.cloned(),
            })
            .collect::<HashSet<_>>()
    }

    #[test]
    fn test_protocol_remove_subset() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("a b");

        let change = test_remove(&mut existing, to_remove);

        assert_eq!(existing, protocol_set_of("c"));
        assert_eq!(change, protocol_set_of("a b"));
    }

    #[test]
    fn test_protocol_remove_all() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("a b c");

        let change = test_remove(&mut existing, to_remove);

        assert_eq!(existing, protocol_set_of(""));
        assert_eq!(change, protocol_set_of("a b c"));
    }

    #[test]
    fn test_protocol_remove_superset() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("a b c d");

        let change = test_remove(&mut existing, to_remove);

        assert_eq!(existing, protocol_set_of(""));
        assert_eq!(change, protocol_set_of("a b c"));
    }

    #[test]
    fn test_protocol_remove_none() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("d");

        let change = test_remove(&mut existing, to_remove);

        assert_eq!(existing, protocol_set_of("a b c"));
        assert_eq!(change, protocol_set_of(""));
    }

    #[test]
    fn test_protocol_remove_none_from_empty() {
        let mut existing = protocol_set_of("");
        let to_remove = protocol_set_of("d");

        let change = test_remove(&mut existing, to_remove);

        assert_eq!(existing, protocol_set_of(""));
        assert_eq!(change, protocol_set_of(""));
    }

    /// Runs `ProtocolsChange::from_full_sets` with every `existing` protocol flagged as
    /// supported, and returns `[removed, added]`.
    fn test_from_full_sets(
        existing: HashSet<StreamProtocol>,
        new: HashSet<StreamProtocol>,
    ) -> [HashSet<StreamProtocol>; 2] {
        let mut buffer = Vec::new();
        let mut existing = existing
            .iter()
            .map(|p| (AsStrHashEq(p.as_ref()), true))
            .collect::<HashMap<_, _>>();

        let changes = ProtocolsChange::from_full_sets(
            &mut existing,
            new.iter().map(AsRef::as_ref),
            &mut buffer,
        );

        let mut added_changes = HashSet::new();
        let mut removed_changes = HashSet::new();

        for change in changes {
            match change {
                ProtocolsChange::Added(a) => added_changes.extend(a.cloned()),
                ProtocolsChange::Removed(r) => removed_changes.extend(r.cloned()),
            }
        }

        [removed_changes, added_changes]
    }

    #[test]
    fn test_from_full_sets_subset() {
        let existing = protocol_set_of("a b c");
        let new = protocol_set_of("a b");

        let [removed_changes, added_changes] = test_from_full_sets(existing, new);

        assert_eq!(added_changes, protocol_set_of(""));
        assert_eq!(removed_changes, protocol_set_of("c"));
    }

    #[test]
    fn test_from_full_sets_superset() {
        let existing = protocol_set_of("a b");
        let new = protocol_set_of("a b c");

        let [removed_changes, added_changes] = test_from_full_sets(existing, new);

        assert_eq!(added_changes, protocol_set_of("c"));
        assert_eq!(removed_changes, protocol_set_of(""));
    }

    #[test]
    fn test_from_full_sets_intersection() {
        let existing = protocol_set_of("a b c");
        let new = protocol_set_of("b c d");

        let [removed_changes, added_changes] = test_from_full_sets(existing, new);

        assert_eq!(added_changes, protocol_set_of("d"));
        assert_eq!(removed_changes, protocol_set_of("a"));
    }

    #[test]
    fn test_from_full_sets_disjoint() {
        let existing = protocol_set_of("a b c");
        let new = protocol_set_of("d e f");

        let [removed_changes, added_changes] = test_from_full_sets(existing, new);

        assert_eq!(added_changes, protocol_set_of("d e f"));
        assert_eq!(removed_changes, protocol_set_of("a b c"));
    }

    #[test]
    fn test_from_full_sets_empty() {
        let existing = protocol_set_of("");
        let new = protocol_set_of("");

        let [removed_changes, added_changes] = test_from_full_sets(existing, new);

        assert_eq!(added_changes, protocol_set_of(""));
        assert_eq!(removed_changes, protocol_set_of(""));
    }
}