1use crate::upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade, UpgradeError};
22use crate::{connection::ConnectedPoint, Negotiated};
23use futures::{future::Either, prelude::*};
24use log::debug;
25use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture};
26use std::{mem, pin::Pin, task::Context, task::Poll};
27
28pub(crate) use multistream_select::Version;
29
30pub(crate) fn apply<C, U>(
33 conn: C,
34 up: U,
35 cp: ConnectedPoint,
36 v: Version,
37) -> Either<InboundUpgradeApply<C, U>, OutboundUpgradeApply<C, U>>
38where
39 C: AsyncRead + AsyncWrite + Unpin,
40 U: InboundConnectionUpgrade<Negotiated<C>> + OutboundConnectionUpgrade<Negotiated<C>>,
41{
42 match cp {
43 ConnectedPoint::Dialer { role_override, .. } if role_override.is_dialer() => {
44 Either::Right(apply_outbound(conn, up, v))
45 }
46 _ => Either::Left(apply_inbound(conn, up)),
47 }
48}
49
50pub(crate) fn apply_inbound<C, U>(conn: C, up: U) -> InboundUpgradeApply<C, U>
52where
53 C: AsyncRead + AsyncWrite + Unpin,
54 U: InboundConnectionUpgrade<Negotiated<C>>,
55{
56 InboundUpgradeApply {
57 inner: InboundUpgradeApplyState::Init {
58 future: multistream_select::listener_select_proto(conn, up.protocol_info()),
59 upgrade: up,
60 },
61 }
62}
63
64pub(crate) fn apply_outbound<C, U>(conn: C, up: U, v: Version) -> OutboundUpgradeApply<C, U>
66where
67 C: AsyncRead + AsyncWrite + Unpin,
68 U: OutboundConnectionUpgrade<Negotiated<C>>,
69{
70 OutboundUpgradeApply {
71 inner: OutboundUpgradeApplyState::Init {
72 future: multistream_select::dialer_select_proto(conn, up.protocol_info(), v),
73 upgrade: up,
74 },
75 }
76}
77
78pub struct InboundUpgradeApply<C, U>
80where
81 C: AsyncRead + AsyncWrite + Unpin,
82 U: InboundConnectionUpgrade<Negotiated<C>>,
83{
84 inner: InboundUpgradeApplyState<C, U>,
85}
86
87#[allow(clippy::large_enum_variant)]
88enum InboundUpgradeApplyState<C, U>
89where
90 C: AsyncRead + AsyncWrite + Unpin,
91 U: InboundConnectionUpgrade<Negotiated<C>>,
92{
93 Init {
94 future: ListenerSelectFuture<C, U::Info>,
95 upgrade: U,
96 },
97 Upgrade {
98 future: Pin<Box<U::Future>>,
99 name: String,
100 },
101 Undefined,
102}
103
104impl<C, U> Unpin for InboundUpgradeApply<C, U>
105where
106 C: AsyncRead + AsyncWrite + Unpin,
107 U: InboundConnectionUpgrade<Negotiated<C>>,
108{
109}
110
111impl<C, U> Future for InboundUpgradeApply<C, U>
112where
113 C: AsyncRead + AsyncWrite + Unpin,
114 U: InboundConnectionUpgrade<Negotiated<C>>,
115{
116 type Output = Result<U::Output, UpgradeError<U::Error>>;
117
118 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
119 loop {
120 match mem::replace(&mut self.inner, InboundUpgradeApplyState::Undefined) {
121 InboundUpgradeApplyState::Init {
122 mut future,
123 upgrade,
124 } => {
125 let (info, io) = match Future::poll(Pin::new(&mut future), cx)? {
126 Poll::Ready(x) => x,
127 Poll::Pending => {
128 self.inner = InboundUpgradeApplyState::Init { future, upgrade };
129 return Poll::Pending;
130 }
131 };
132 self.inner = InboundUpgradeApplyState::Upgrade {
133 future: Box::pin(upgrade.upgrade_inbound(io, info.clone())),
134 name: info.as_ref().to_owned(),
135 };
136 }
137 InboundUpgradeApplyState::Upgrade { mut future, name } => {
138 match Future::poll(Pin::new(&mut future), cx) {
139 Poll::Pending => {
140 self.inner = InboundUpgradeApplyState::Upgrade { future, name };
141 return Poll::Pending;
142 }
143 Poll::Ready(Ok(x)) => {
144 log::trace!("Upgraded inbound stream to {name}");
145 return Poll::Ready(Ok(x));
146 }
147 Poll::Ready(Err(e)) => {
148 debug!("Failed to upgrade inbound stream to {name}");
149 return Poll::Ready(Err(UpgradeError::Apply(e)));
150 }
151 }
152 }
153 InboundUpgradeApplyState::Undefined => {
154 panic!("InboundUpgradeApplyState::poll called after completion")
155 }
156 }
157 }
158 }
159}
160
161pub struct OutboundUpgradeApply<C, U>
163where
164 C: AsyncRead + AsyncWrite + Unpin,
165 U: OutboundConnectionUpgrade<Negotiated<C>>,
166{
167 inner: OutboundUpgradeApplyState<C, U>,
168}
169
170enum OutboundUpgradeApplyState<C, U>
171where
172 C: AsyncRead + AsyncWrite + Unpin,
173 U: OutboundConnectionUpgrade<Negotiated<C>>,
174{
175 Init {
176 future: DialerSelectFuture<C, <U::InfoIter as IntoIterator>::IntoIter>,
177 upgrade: U,
178 },
179 Upgrade {
180 future: Pin<Box<U::Future>>,
181 name: String,
182 },
183 Undefined,
184}
185
186impl<C, U> Unpin for OutboundUpgradeApply<C, U>
187where
188 C: AsyncRead + AsyncWrite + Unpin,
189 U: OutboundConnectionUpgrade<Negotiated<C>>,
190{
191}
192
193impl<C, U> Future for OutboundUpgradeApply<C, U>
194where
195 C: AsyncRead + AsyncWrite + Unpin,
196 U: OutboundConnectionUpgrade<Negotiated<C>>,
197{
198 type Output = Result<U::Output, UpgradeError<U::Error>>;
199
200 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
201 loop {
202 match mem::replace(&mut self.inner, OutboundUpgradeApplyState::Undefined) {
203 OutboundUpgradeApplyState::Init {
204 mut future,
205 upgrade,
206 } => {
207 let (info, connection) = match Future::poll(Pin::new(&mut future), cx)? {
208 Poll::Ready(x) => x,
209 Poll::Pending => {
210 self.inner = OutboundUpgradeApplyState::Init { future, upgrade };
211 return Poll::Pending;
212 }
213 };
214 self.inner = OutboundUpgradeApplyState::Upgrade {
215 future: Box::pin(upgrade.upgrade_outbound(connection, info.clone())),
216 name: info.as_ref().to_owned(),
217 };
218 }
219 OutboundUpgradeApplyState::Upgrade { mut future, name } => {
220 match Future::poll(Pin::new(&mut future), cx) {
221 Poll::Pending => {
222 self.inner = OutboundUpgradeApplyState::Upgrade { future, name };
223 return Poll::Pending;
224 }
225 Poll::Ready(Ok(x)) => {
226 log::trace!("Upgraded outbound stream to {name}",);
227 return Poll::Ready(Ok(x));
228 }
229 Poll::Ready(Err(e)) => {
230 debug!("Failed to upgrade outbound stream to {name}",);
231 return Poll::Ready(Err(UpgradeError::Apply(e)));
232 }
233 }
234 }
235 OutboundUpgradeApplyState::Undefined => {
236 panic!("OutboundUpgradeApplyState::poll called after completion")
237 }
238 }
239 }
240 }
241}