1use crate::handler::{
22 ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent, DialUpgradeError,
23 FullyNegotiatedInbound, FullyNegotiatedOutbound, SubstreamProtocol,
24};
25use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend};
26use crate::StreamUpgradeError;
27use smallvec::SmallVec;
28use std::{error, fmt::Debug, task::Context, task::Poll, time::Duration};
29
30pub struct OneShotHandler<TInbound, TOutbound, TEvent>
33where
34 TOutbound: OutboundUpgradeSend,
35{
36 listen_protocol: SubstreamProtocol<TInbound, ()>,
38 events_out: SmallVec<[Result<TEvent, StreamUpgradeError<TOutbound::Error>>; 4]>,
40 dial_queue: SmallVec<[TOutbound; 4]>,
42 dial_negotiated: u32,
44 config: OneShotHandlerConfig,
46}
47
48impl<TInbound, TOutbound, TEvent> OneShotHandler<TInbound, TOutbound, TEvent>
49where
50 TOutbound: OutboundUpgradeSend,
51{
52 pub fn new(
54 listen_protocol: SubstreamProtocol<TInbound, ()>,
55 config: OneShotHandlerConfig,
56 ) -> Self {
57 OneShotHandler {
58 listen_protocol,
59 events_out: SmallVec::new(),
60 dial_queue: SmallVec::new(),
61 dial_negotiated: 0,
62 config,
63 }
64 }
65
66 pub fn pending_requests(&self) -> u32 {
68 self.dial_negotiated + self.dial_queue.len() as u32
69 }
70
71 pub fn listen_protocol_ref(&self) -> &SubstreamProtocol<TInbound, ()> {
76 &self.listen_protocol
77 }
78
79 pub fn listen_protocol_mut(&mut self) -> &mut SubstreamProtocol<TInbound, ()> {
84 &mut self.listen_protocol
85 }
86
87 pub fn send_request(&mut self, upgrade: TOutbound) {
89 self.dial_queue.push(upgrade);
90 }
91}
92
93impl<TInbound, TOutbound, TEvent> Default for OneShotHandler<TInbound, TOutbound, TEvent>
94where
95 TOutbound: OutboundUpgradeSend,
96 TInbound: InboundUpgradeSend + Default,
97{
98 fn default() -> Self {
99 OneShotHandler::new(
100 SubstreamProtocol::new(Default::default(), ()),
101 OneShotHandlerConfig::default(),
102 )
103 }
104}
105
106impl<TInbound, TOutbound, TEvent> ConnectionHandler for OneShotHandler<TInbound, TOutbound, TEvent>
107where
108 TInbound: InboundUpgradeSend + Send + 'static,
109 TOutbound: Debug + OutboundUpgradeSend,
110 TInbound::Output: Into<TEvent>,
111 TOutbound::Output: Into<TEvent>,
112 TOutbound::Error: error::Error + Send + 'static,
113 SubstreamProtocol<TInbound, ()>: Clone,
114 TEvent: Debug + Send + 'static,
115{
116 type FromBehaviour = TOutbound;
117 type ToBehaviour = Result<TEvent, StreamUpgradeError<TOutbound::Error>>;
118 type InboundProtocol = TInbound;
119 type OutboundProtocol = TOutbound;
120 type OutboundOpenInfo = ();
121 type InboundOpenInfo = ();
122
123 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
124 self.listen_protocol.clone()
125 }
126
127 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
128 self.send_request(event);
129 }
130
131 fn poll(
132 &mut self,
133 _: &mut Context<'_>,
134 ) -> Poll<
135 ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
136 > {
137 if !self.events_out.is_empty() {
138 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
139 self.events_out.remove(0),
140 ));
141 } else {
142 self.events_out.shrink_to_fit();
143 }
144
145 if !self.dial_queue.is_empty() {
146 if self.dial_negotiated < self.config.max_dial_negotiated {
147 self.dial_negotiated += 1;
148 let upgrade = self.dial_queue.remove(0);
149 return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
150 protocol: SubstreamProtocol::new(upgrade, ())
151 .with_timeout(self.config.outbound_substream_timeout),
152 });
153 }
154 } else {
155 self.dial_queue.shrink_to_fit();
156 }
157
158 Poll::Pending
159 }
160
161 fn on_connection_event(
162 &mut self,
163 event: ConnectionEvent<
164 Self::InboundProtocol,
165 Self::OutboundProtocol,
166 Self::InboundOpenInfo,
167 Self::OutboundOpenInfo,
168 >,
169 ) {
170 match event {
171 ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
172 protocol: out,
173 ..
174 }) => {
175 self.events_out.push(Ok(out.into()));
176 }
177 ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
178 protocol: out,
179 ..
180 }) => {
181 self.dial_negotiated -= 1;
182 self.events_out.push(Ok(out.into()));
183 }
184 ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
185 self.events_out.push(Err(error));
186 }
187 ConnectionEvent::AddressChange(_)
188 | ConnectionEvent::ListenUpgradeError(_)
189 | ConnectionEvent::LocalProtocolsChange(_)
190 | ConnectionEvent::RemoteProtocolsChange(_) => {}
191 }
192 }
193}
194
/// Configuration for a [`OneShotHandler`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OneShotHandlerConfig {
    /// Timeout applied to every outbound substream request the handler emits.
    pub outbound_substream_timeout: Duration,
    /// Maximum number of outbound substreams allowed to be negotiating at once;
    /// further requests wait in the queue.
    pub max_dial_negotiated: u32,
}

impl Default for OneShotHandlerConfig {
    /// A 10-second outbound timeout and at most 8 concurrent outbound negotiations.
    fn default() -> Self {
        OneShotHandlerConfig {
            outbound_substream_timeout: Duration::from_secs(10),
            max_dial_negotiated: 8,
        }
    }
}
212
#[cfg(test)]
mod tests {
    use super::*;

    use futures::executor::block_on;
    use futures::future::poll_fn;
    use libp2p_core::upgrade::DeniedUpgrade;
    use void::Void;

    /// A handler with nothing queued and nothing in flight must drain to
    /// `Poll::Pending` and must not ask for the connection to be kept alive.
    #[test]
    fn do_not_keep_idle_connection_alive() {
        let listen = SubstreamProtocol::new(DeniedUpgrade {}, ());
        let mut handler: OneShotHandler<_, DeniedUpgrade, Void> =
            OneShotHandler::new(listen, OneShotHandlerConfig::default());

        // Keep polling until the handler reports `Pending`, then finish the future.
        block_on(poll_fn(|cx| {
            while handler.poll(cx).is_ready() {}
            Poll::Ready(())
        }));

        assert!(!handler.connection_keep_alive());
    }
}