1use crate::handler::{
22 ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent, DialUpgradeError,
23 FullyNegotiatedInbound, FullyNegotiatedOutbound, KeepAlive, StreamUpgradeError,
24 SubstreamProtocol,
25};
26use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend};
27use instant::Instant;
28use smallvec::SmallVec;
29use std::{error, fmt::Debug, task::Context, task::Poll, time::Duration};
30
31pub struct OneShotHandler<TInbound, TOutbound, TEvent>
34where
35 TOutbound: OutboundUpgradeSend,
36{
37 listen_protocol: SubstreamProtocol<TInbound, ()>,
39 pending_error: Option<StreamUpgradeError<<TOutbound as OutboundUpgradeSend>::Error>>,
41 events_out: SmallVec<[TEvent; 4]>,
43 dial_queue: SmallVec<[TOutbound; 4]>,
45 dial_negotiated: u32,
47 keep_alive: KeepAlive,
49 config: OneShotHandlerConfig,
51}
52
53impl<TInbound, TOutbound, TEvent> OneShotHandler<TInbound, TOutbound, TEvent>
54where
55 TOutbound: OutboundUpgradeSend,
56{
57 pub fn new(
59 listen_protocol: SubstreamProtocol<TInbound, ()>,
60 config: OneShotHandlerConfig,
61 ) -> Self {
62 OneShotHandler {
63 listen_protocol,
64 pending_error: None,
65 events_out: SmallVec::new(),
66 dial_queue: SmallVec::new(),
67 dial_negotiated: 0,
68 keep_alive: KeepAlive::Yes,
69 config,
70 }
71 }
72
73 pub fn pending_requests(&self) -> u32 {
75 self.dial_negotiated + self.dial_queue.len() as u32
76 }
77
78 pub fn listen_protocol_ref(&self) -> &SubstreamProtocol<TInbound, ()> {
83 &self.listen_protocol
84 }
85
86 pub fn listen_protocol_mut(&mut self) -> &mut SubstreamProtocol<TInbound, ()> {
91 &mut self.listen_protocol
92 }
93
94 pub fn send_request(&mut self, upgrade: TOutbound) {
96 self.keep_alive = KeepAlive::Yes;
97 self.dial_queue.push(upgrade);
98 }
99}
100
101impl<TInbound, TOutbound, TEvent> Default for OneShotHandler<TInbound, TOutbound, TEvent>
102where
103 TOutbound: OutboundUpgradeSend,
104 TInbound: InboundUpgradeSend + Default,
105{
106 fn default() -> Self {
107 OneShotHandler::new(
108 SubstreamProtocol::new(Default::default(), ()),
109 OneShotHandlerConfig::default(),
110 )
111 }
112}
113
114impl<TInbound, TOutbound, TEvent> ConnectionHandler for OneShotHandler<TInbound, TOutbound, TEvent>
115where
116 TInbound: InboundUpgradeSend + Send + 'static,
117 TOutbound: Debug + OutboundUpgradeSend,
118 TInbound::Output: Into<TEvent>,
119 TOutbound::Output: Into<TEvent>,
120 TOutbound::Error: error::Error + Send + 'static,
121 SubstreamProtocol<TInbound, ()>: Clone,
122 TEvent: Debug + Send + 'static,
123{
124 type FromBehaviour = TOutbound;
125 type ToBehaviour = TEvent;
126 type Error = StreamUpgradeError<<Self::OutboundProtocol as OutboundUpgradeSend>::Error>;
127 type InboundProtocol = TInbound;
128 type OutboundProtocol = TOutbound;
129 type OutboundOpenInfo = ();
130 type InboundOpenInfo = ();
131
132 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
133 self.listen_protocol.clone()
134 }
135
136 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
137 self.send_request(event);
138 }
139
140 fn connection_keep_alive(&self) -> KeepAlive {
141 self.keep_alive
142 }
143
144 #[allow(deprecated)]
145 fn poll(
146 &mut self,
147 _: &mut Context<'_>,
148 ) -> Poll<
149 ConnectionHandlerEvent<
150 Self::OutboundProtocol,
151 Self::OutboundOpenInfo,
152 Self::ToBehaviour,
153 Self::Error,
154 >,
155 > {
156 if let Some(err) = self.pending_error.take() {
157 #[allow(deprecated)]
158 return Poll::Ready(ConnectionHandlerEvent::Close(err));
159 }
160
161 if !self.events_out.is_empty() {
162 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
163 self.events_out.remove(0),
164 ));
165 } else {
166 self.events_out.shrink_to_fit();
167 }
168
169 if !self.dial_queue.is_empty() {
170 if self.dial_negotiated < self.config.max_dial_negotiated {
171 self.dial_negotiated += 1;
172 let upgrade = self.dial_queue.remove(0);
173 return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
174 protocol: SubstreamProtocol::new(upgrade, ())
175 .with_timeout(self.config.outbound_substream_timeout),
176 });
177 }
178 } else {
179 self.dial_queue.shrink_to_fit();
180
181 #[allow(deprecated)]
182 if self.dial_negotiated == 0 && self.keep_alive.is_yes() {
183 self.keep_alive = KeepAlive::Until(Instant::now() + self.config.keep_alive_timeout);
184 }
185 }
186
187 Poll::Pending
188 }
189
190 fn on_connection_event(
191 &mut self,
192 event: ConnectionEvent<
193 Self::InboundProtocol,
194 Self::OutboundProtocol,
195 Self::InboundOpenInfo,
196 Self::OutboundOpenInfo,
197 >,
198 ) {
199 match event {
200 ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
201 protocol: out,
202 ..
203 }) => {
204 #[allow(deprecated)]
206 if !self.keep_alive.is_yes() {
207 self.keep_alive =
208 KeepAlive::Until(Instant::now() + self.config.keep_alive_timeout);
209 }
210
211 self.events_out.push(out.into());
212 }
213 ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
214 protocol: out,
215 ..
216 }) => {
217 self.dial_negotiated -= 1;
218 self.events_out.push(out.into());
219 }
220 ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
221 if self.pending_error.is_none() {
222 log::debug!("DialUpgradeError: {error}");
223 self.keep_alive = KeepAlive::No;
224 }
225 }
226 ConnectionEvent::AddressChange(_)
227 | ConnectionEvent::ListenUpgradeError(_)
228 | ConnectionEvent::LocalProtocolsChange(_)
229 | ConnectionEvent::RemoteProtocolsChange(_) => {}
230 }
231 }
232}
233
/// Timeouts and limits used by [`OneShotHandler`].
#[derive(Debug)]
pub struct OneShotHandlerConfig {
    /// Length of the idle countdown the handler reports via `KeepAlive::Until`.
    #[deprecated(
        note = "Set a global idle connection timeout via `SwarmBuilder::idle_connection_timeout` instead."
    )]
    pub keep_alive_timeout: Duration,
    /// Timeout attached to every outbound substream request.
    pub outbound_substream_timeout: Duration,
    /// Maximum number of outbound substream requests allowed in flight at the same time.
    pub max_dial_negotiated: u32,
}
247
248impl Default for OneShotHandlerConfig {
249 #[allow(deprecated)]
250 fn default() -> Self {
251 OneShotHandlerConfig {
252 keep_alive_timeout: Duration::from_secs(10),
253 outbound_substream_timeout: Duration::from_secs(10),
254 max_dial_negotiated: 8,
255 }
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    use futures::executor::block_on;
    use futures::future::poll_fn;
    use libp2p_core::upgrade::DeniedUpgrade;
    use void::Void;

    /// A freshly created handler with no queued work must, once polled until it has
    /// nothing more to report, switch its keep-alive from `Yes` to a timed `Until`.
    #[test]
    #[allow(deprecated)] // matches on `KeepAlive::Until`
    fn do_not_keep_idle_connection_alive() {
        let protocol = SubstreamProtocol::new(DeniedUpgrade {}, ());
        let mut handler: OneShotHandler<_, DeniedUpgrade, Void> =
            OneShotHandler::new(protocol, OneShotHandlerConfig::default());

        // Drain every ready event; stop at the first `Pending`.
        block_on(poll_fn(|cx| {
            while handler.poll(cx).is_ready() {}
            Poll::Ready(())
        }));

        let keep_alive = handler.connection_keep_alive();
        assert!(matches!(keep_alive, KeepAlive::Until(_)));
    }
}