litep2p/multistream_select/
negotiated.rs1use crate::multistream_select::protocol::{
22 HeaderLine, Message, MessageReader, Protocol, ProtocolError,
23};
24
25use futures::{
26 io::{IoSlice, IoSliceMut},
27 prelude::*,
28 ready,
29};
30use pin_project::pin_project;
31use std::{
32 error::Error,
33 fmt, io, mem,
34 pin::Pin,
35 task::{Context, Poll},
36};
37
/// Target used for every `tracing` event emitted by this module.
const LOG_TARGET: &str = "litep2p::multistream-select";

40#[pin_project]
52#[derive(Debug)]
53pub struct Negotiated<TInner> {
54 #[pin]
55 state: State<TInner>,
56}
57
58#[derive(Debug)]
60pub struct NegotiatedComplete<TInner> {
61 inner: Option<Negotiated<TInner>>,
62}
63
64impl<TInner> Future for NegotiatedComplete<TInner>
65where
66 TInner: AsyncRead + AsyncWrite + Unpin,
71{
72 type Output = Result<Negotiated<TInner>, NegotiationError>;
73
74 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75 let mut io = self.inner.take().expect("NegotiatedFuture called after completion.");
76 match Negotiated::poll(Pin::new(&mut io), cx) {
77 Poll::Pending => {
78 self.inner = Some(io);
79 Poll::Pending
80 }
81 Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
82 Poll::Ready(Err(err)) => {
83 self.inner = Some(io);
84 Poll::Ready(Err(err))
85 }
86 }
87 }
88}
89
90impl<TInner> Negotiated<TInner> {
91 pub(crate) fn completed(io: TInner) -> Self {
93 Negotiated {
94 state: State::Completed { io },
95 }
96 }
97
98 pub(crate) fn expecting(
101 io: MessageReader<TInner>,
102 protocol: Protocol,
103 header: Option<HeaderLine>,
104 ) -> Self {
105 Negotiated {
106 state: State::Expecting {
107 io,
108 protocol,
109 header,
110 },
111 }
112 }
113
114 pub fn inner(self) -> TInner {
115 match self.state {
116 State::Completed { io } => io,
117 _ => panic!("stream is not negotiated"),
118 }
119 }
120
121 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), NegotiationError>>
123 where
124 TInner: AsyncRead + AsyncWrite + Unpin,
125 {
126 match self.as_mut().poll_flush(cx) {
128 Poll::Ready(Ok(())) => {}
129 Poll::Pending => return Poll::Pending,
130 Poll::Ready(Err(e)) => {
131 if e.kind() != io::ErrorKind::WriteZero {
134 return Poll::Ready(Err(e.into()));
135 }
136 }
137 }
138
139 let mut this = self.project();
140
141 if let StateProj::Completed { .. } = this.state.as_mut().project() {
142 return Poll::Ready(Ok(()));
143 }
144
145 loop {
147 match mem::replace(&mut *this.state, State::Invalid) {
148 State::Expecting {
149 mut io,
150 header,
151 protocol,
152 } => {
153 let msg = match Pin::new(&mut io).poll_next(cx)? {
154 Poll::Ready(Some(msg)) => msg,
155 Poll::Pending => {
156 *this.state = State::Expecting {
157 io,
158 header,
159 protocol,
160 };
161 return Poll::Pending;
162 }
163 Poll::Ready(None) => {
164 return Poll::Ready(Err(ProtocolError::IoError(
165 io::ErrorKind::UnexpectedEof.into(),
166 )
167 .into()));
168 }
169 };
170
171 if let Message::Header(h) = &msg {
172 if Some(h) == header.as_ref() {
173 *this.state = State::Expecting {
174 io,
175 protocol,
176 header: None,
177 };
178 continue;
179 } else {
180 return Poll::Ready(Err(ProtocolError::InvalidMessage.into()));
183 }
184 }
185
186 if let Message::Protocol(p) = &msg {
187 if p.as_ref() == protocol.as_ref() {
188 tracing::debug!(
189 target: LOG_TARGET,
190 "Negotiated: Received confirmation for protocol: {}",
191 p
192 );
193 *this.state = State::Completed {
194 io: io.into_inner(),
195 };
196 return Poll::Ready(Ok(()));
197 }
198 }
199
200 return Poll::Ready(Err(NegotiationError::Failed));
201 }
202
203 _ => panic!("Negotiated: Invalid state"),
204 }
205 }
206 }
207
208 pub fn complete(self) -> NegotiatedComplete<TInner> {
211 NegotiatedComplete { inner: Some(self) }
212 }
213}
214
215#[pin_project(project = StateProj)]
217#[derive(Debug)]
218enum State<R> {
219 Expecting {
223 #[pin]
225 io: MessageReader<R>,
226 header: Option<HeaderLine>,
229 protocol: Protocol,
231 },
232
233 Completed {
236 #[pin]
237 io: R,
238 },
239
240 Invalid,
243}
244
245impl<TInner> AsyncRead for Negotiated<TInner>
246where
247 TInner: AsyncRead + AsyncWrite + Unpin,
248{
249 fn poll_read(
250 mut self: Pin<&mut Self>,
251 cx: &mut Context<'_>,
252 buf: &mut [u8],
253 ) -> Poll<Result<usize, io::Error>> {
254 loop {
255 if let StateProj::Completed { io } = self.as_mut().project().state.project() {
256 return io.poll_read(cx, buf);
258 }
259
260 match self.as_mut().poll(cx) {
263 Poll::Ready(Ok(())) => {}
264 Poll::Pending => return Poll::Pending,
265 Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
266 }
267 }
268 }
269
270 fn poll_read_vectored(
280 mut self: Pin<&mut Self>,
281 cx: &mut Context<'_>,
282 bufs: &mut [IoSliceMut<'_>],
283 ) -> Poll<Result<usize, io::Error>> {
284 loop {
285 if let StateProj::Completed { io } = self.as_mut().project().state.project() {
286 return io.poll_read_vectored(cx, bufs);
288 }
289
290 match self.as_mut().poll(cx) {
293 Poll::Ready(Ok(())) => {}
294 Poll::Pending => return Poll::Pending,
295 Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
296 }
297 }
298 }
299}
300
301impl<TInner> AsyncWrite for Negotiated<TInner>
302where
303 TInner: AsyncWrite + AsyncRead + Unpin,
304{
305 fn poll_write(
306 self: Pin<&mut Self>,
307 cx: &mut Context<'_>,
308 buf: &[u8],
309 ) -> Poll<Result<usize, io::Error>> {
310 match self.project().state.project() {
311 StateProj::Completed { io } => io.poll_write(cx, buf),
312 StateProj::Expecting { io, .. } => io.poll_write(cx, buf),
313 StateProj::Invalid => panic!("Negotiated: Invalid state"),
314 }
315 }
316
317 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
318 match self.project().state.project() {
319 StateProj::Completed { io } => io.poll_flush(cx),
320 StateProj::Expecting { io, .. } => io.poll_flush(cx),
321 StateProj::Invalid => panic!("Negotiated: Invalid state"),
322 }
323 }
324
325 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
326 ready!(self.as_mut().poll(cx).map_err(Into::<io::Error>::into)?);
329 ready!(self.as_mut().poll_flush(cx).map_err(Into::<io::Error>::into)?);
330
331 match self.project().state.project() {
333 StateProj::Completed { io, .. } => io.poll_close(cx),
334 StateProj::Expecting { io, .. } => io.poll_close(cx),
335 StateProj::Invalid => panic!("Negotiated: Invalid state"),
336 }
337 }
338
339 fn poll_write_vectored(
340 self: Pin<&mut Self>,
341 cx: &mut Context<'_>,
342 bufs: &[IoSlice<'_>],
343 ) -> Poll<Result<usize, io::Error>> {
344 match self.project().state.project() {
345 StateProj::Completed { io } => io.poll_write_vectored(cx, bufs),
346 StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
347 StateProj::Invalid => panic!("Negotiated: Invalid state"),
348 }
349 }
350}
351
352#[derive(Debug, thiserror::Error, PartialEq)]
354pub enum NegotiationError {
355 #[error("A protocol error occurred during the negotiation: `{0:?}`")]
357 ProtocolError(#[from] ProtocolError),
358
359 #[error("Protocol negotiation failed.")]
361 Failed,
362}
363
364impl From<io::Error> for NegotiationError {
365 fn from(err: io::Error) -> NegotiationError {
366 ProtocolError::from(err).into()
367 }
368}
369
370impl From<NegotiationError> for io::Error {
371 fn from(err: NegotiationError) -> io::Error {
372 if let NegotiationError::ProtocolError(e) = err {
373 return e.into();
374 }
375 io::Error::new(io::ErrorKind::Other, err)
376 }
377}