litep2p/multistream_select/
negotiated.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::multistream_select::protocol::{
22    HeaderLine, Message, MessageReader, Protocol, ProtocolError,
23};
24
25use futures::{
26    io::{IoSlice, IoSliceMut},
27    prelude::*,
28    ready,
29};
30use pin_project::pin_project;
31use std::{
32    error::Error,
33    fmt, io, mem,
34    pin::Pin,
35    task::{Context, Poll},
36};
37
/// Logging target used by the `tracing` events emitted from this file.
const LOG_TARGET: &str = "litep2p::multistream-select";
39
/// An I/O stream that has settled on an (application-layer) protocol to use.
///
/// A `Negotiated` represents an I/O stream that has _settled_ on a protocol
/// to use. In particular, it is not implied that all of the protocol negotiation
/// frames have yet been sent and / or received, just that the selected protocol
/// is fully determined. This is to allow the last protocol negotiation frames
/// sent by a peer to be combined in a single write, possibly piggy-backing
/// data from the negotiated protocol on top.
///
/// Reading from a `Negotiated` I/O stream that still has pending negotiation
/// protocol data to send implicitly triggers flushing of all yet unsent data.
///
/// Use [`Negotiated::complete`] to obtain a future that resolves once the
/// negotiation has finished.
#[pin_project]
#[derive(Debug)]
pub struct Negotiated<TInner> {
    /// Current negotiation state, which also owns the underlying I/O resource.
    /// Pinned so that the wrapped resource can be polled in place.
    #[pin]
    state: State<TInner>,
}
57
/// A `Future` that waits on the completion of protocol negotiation.
///
/// Created by [`Negotiated::complete`]. It resolves to the [`Negotiated`]
/// stream once the negotiation has finished, or to a [`NegotiationError`].
#[derive(Debug)]
pub struct NegotiatedComplete<TInner> {
    /// The stream being driven. It is taken out for the duration of each poll
    /// and stays `None` once the future has resolved successfully.
    inner: Option<Negotiated<TInner>>,
}
63
64impl<TInner> Future for NegotiatedComplete<TInner>
65where
66    // `Unpin` is required not because of
67    // implementation details but because we produce
68    // the `Negotiated` as the output of the
69    // future.
70    TInner: AsyncRead + AsyncWrite + Unpin,
71{
72    type Output = Result<Negotiated<TInner>, NegotiationError>;
73
74    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75        let mut io = self.inner.take().expect("NegotiatedFuture called after completion.");
76        match Negotiated::poll(Pin::new(&mut io), cx) {
77            Poll::Pending => {
78                self.inner = Some(io);
79                Poll::Pending
80            }
81            Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
82            Poll::Ready(Err(err)) => {
83                self.inner = Some(io);
84                Poll::Ready(Err(err))
85            }
86        }
87    }
88}
89
90impl<TInner> Negotiated<TInner> {
91    /// Creates a `Negotiated` in state [`State::Completed`].
92    pub(crate) fn completed(io: TInner) -> Self {
93        Negotiated {
94            state: State::Completed { io },
95        }
96    }
97
98    /// Creates a `Negotiated` in state [`State::Expecting`] that is still
99    /// expecting confirmation of the given `protocol`.
100    pub(crate) fn expecting(
101        io: MessageReader<TInner>,
102        protocol: Protocol,
103        header: Option<HeaderLine>,
104    ) -> Self {
105        Negotiated {
106            state: State::Expecting {
107                io,
108                protocol,
109                header,
110            },
111        }
112    }
113
114    pub fn inner(self) -> TInner {
115        match self.state {
116            State::Completed { io } => io,
117            _ => panic!("stream is not negotiated"),
118        }
119    }
120
121    /// Polls the `Negotiated` for completion.
122    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), NegotiationError>>
123    where
124        TInner: AsyncRead + AsyncWrite + Unpin,
125    {
126        // Flush any pending negotiation data.
127        match self.as_mut().poll_flush(cx) {
128            Poll::Ready(Ok(())) => {}
129            Poll::Pending => return Poll::Pending,
130            Poll::Ready(Err(e)) => {
131                // If the remote closed the stream, it is important to still
132                // continue reading the data that was sent, if any.
133                if e.kind() != io::ErrorKind::WriteZero {
134                    return Poll::Ready(Err(e.into()));
135                }
136            }
137        }
138
139        let mut this = self.project();
140
141        if let StateProj::Completed { .. } = this.state.as_mut().project() {
142            return Poll::Ready(Ok(()));
143        }
144
145        // Read outstanding protocol negotiation messages.
146        loop {
147            match mem::replace(&mut *this.state, State::Invalid) {
148                State::Expecting {
149                    mut io,
150                    header,
151                    protocol,
152                } => {
153                    let msg = match Pin::new(&mut io).poll_next(cx)? {
154                        Poll::Ready(Some(msg)) => msg,
155                        Poll::Pending => {
156                            *this.state = State::Expecting {
157                                io,
158                                header,
159                                protocol,
160                            };
161                            return Poll::Pending;
162                        }
163                        Poll::Ready(None) => {
164                            return Poll::Ready(Err(ProtocolError::IoError(
165                                io::ErrorKind::UnexpectedEof.into(),
166                            )
167                            .into()));
168                        }
169                    };
170
171                    if let Message::Header(h) = &msg {
172                        if Some(h) == header.as_ref() {
173                            *this.state = State::Expecting {
174                                io,
175                                protocol,
176                                header: None,
177                            };
178                            continue;
179                        } else {
180                            // If we received a header message but it doesn't match the expected
181                            // one, or we have already received the message return an error.
182                            return Poll::Ready(Err(ProtocolError::InvalidMessage.into()));
183                        }
184                    }
185
186                    if let Message::Protocol(p) = &msg {
187                        if p.as_ref() == protocol.as_ref() {
188                            tracing::debug!(
189                                target: LOG_TARGET,
190                                "Negotiated: Received confirmation for protocol: {}",
191                                p
192                            );
193                            *this.state = State::Completed {
194                                io: io.into_inner(),
195                            };
196                            return Poll::Ready(Ok(()));
197                        }
198                    }
199
200                    return Poll::Ready(Err(NegotiationError::Failed));
201                }
202
203                _ => panic!("Negotiated: Invalid state"),
204            }
205        }
206    }
207
208    /// Returns a [`NegotiatedComplete`] future that waits for protocol
209    /// negotiation to complete.
210    pub fn complete(self) -> NegotiatedComplete<TInner> {
211        NegotiatedComplete { inner: Some(self) }
212    }
213}
214
/// The states of a `Negotiated` I/O stream.
///
/// The pinned projection of this enum is named `StateProj`.
#[pin_project(project = StateProj)]
#[derive(Debug)]
enum State<R> {
    /// In this state, a `Negotiated` is still expecting to
    /// receive confirmation of the protocol it has optimistically
    /// settled on.
    Expecting {
        /// The underlying I/O stream, wrapped in a reader for negotiation messages.
        #[pin]
        io: MessageReader<R>,
        /// The expected negotiation header/preamble (i.e. multistream-select version),
        /// if one is still expected to be received.
        header: Option<HeaderLine>,
        /// The expected application protocol (i.e. name and version).
        protocol: Protocol,
    },

    /// In this state, a protocol has been agreed upon and I/O
    /// on the underlying stream can commence.
    Completed {
        /// The underlying I/O stream.
        #[pin]
        io: R,
    },

    /// Temporary state while moving the `io` resource from
    /// `Expecting` to `Completed`.
    ///
    /// The `AsyncWrite` methods of `Negotiated` panic when they observe this state.
    Invalid,
}
244
245impl<TInner> AsyncRead for Negotiated<TInner>
246where
247    TInner: AsyncRead + AsyncWrite + Unpin,
248{
249    fn poll_read(
250        mut self: Pin<&mut Self>,
251        cx: &mut Context<'_>,
252        buf: &mut [u8],
253    ) -> Poll<Result<usize, io::Error>> {
254        loop {
255            if let StateProj::Completed { io } = self.as_mut().project().state.project() {
256                // If protocol negotiation is complete, commence with reading.
257                return io.poll_read(cx, buf);
258            }
259
260            // Poll the `Negotiated`, driving protocol negotiation to completion,
261            // including flushing of any remaining data.
262            match self.as_mut().poll(cx) {
263                Poll::Ready(Ok(())) => {}
264                Poll::Pending => return Poll::Pending,
265                Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
266            }
267        }
268    }
269
270    // TODO: implement once method is stabilized in the futures crate
271    /*unsafe fn initializer(&self) -> Initializer {
272        match &self.state {
273            State::Completed { io, .. } => io.initializer(),
274            State::Expecting { io, .. } => io.inner_ref().initializer(),
275            State::Invalid => panic!("Negotiated: Invalid state"),
276        }
277    }*/
278
279    fn poll_read_vectored(
280        mut self: Pin<&mut Self>,
281        cx: &mut Context<'_>,
282        bufs: &mut [IoSliceMut<'_>],
283    ) -> Poll<Result<usize, io::Error>> {
284        loop {
285            if let StateProj::Completed { io } = self.as_mut().project().state.project() {
286                // If protocol negotiation is complete, commence with reading.
287                return io.poll_read_vectored(cx, bufs);
288            }
289
290            // Poll the `Negotiated`, driving protocol negotiation to completion,
291            // including flushing of any remaining data.
292            match self.as_mut().poll(cx) {
293                Poll::Ready(Ok(())) => {}
294                Poll::Pending => return Poll::Pending,
295                Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
296            }
297        }
298    }
299}
300
301impl<TInner> AsyncWrite for Negotiated<TInner>
302where
303    TInner: AsyncWrite + AsyncRead + Unpin,
304{
305    fn poll_write(
306        self: Pin<&mut Self>,
307        cx: &mut Context<'_>,
308        buf: &[u8],
309    ) -> Poll<Result<usize, io::Error>> {
310        match self.project().state.project() {
311            StateProj::Completed { io } => io.poll_write(cx, buf),
312            StateProj::Expecting { io, .. } => io.poll_write(cx, buf),
313            StateProj::Invalid => panic!("Negotiated: Invalid state"),
314        }
315    }
316
317    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
318        match self.project().state.project() {
319            StateProj::Completed { io } => io.poll_flush(cx),
320            StateProj::Expecting { io, .. } => io.poll_flush(cx),
321            StateProj::Invalid => panic!("Negotiated: Invalid state"),
322        }
323    }
324
325    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
326        // Ensure all data has been flushed and expected negotiation messages
327        // have been received.
328        ready!(self.as_mut().poll(cx).map_err(Into::<io::Error>::into)?);
329        ready!(self.as_mut().poll_flush(cx).map_err(Into::<io::Error>::into)?);
330
331        // Continue with the shutdown of the underlying I/O stream.
332        match self.project().state.project() {
333            StateProj::Completed { io, .. } => io.poll_close(cx),
334            StateProj::Expecting { io, .. } => io.poll_close(cx),
335            StateProj::Invalid => panic!("Negotiated: Invalid state"),
336        }
337    }
338
339    fn poll_write_vectored(
340        self: Pin<&mut Self>,
341        cx: &mut Context<'_>,
342        bufs: &[IoSlice<'_>],
343    ) -> Poll<Result<usize, io::Error>> {
344        match self.project().state.project() {
345            StateProj::Completed { io } => io.poll_write_vectored(cx, bufs),
346            StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
347            StateProj::Invalid => panic!("Negotiated: Invalid state"),
348        }
349    }
350}
351
/// Error that can happen when negotiating a protocol with the remote.
///
/// Can be converted into an [`io::Error`] so that it can be reported through
/// the `AsyncRead` / `AsyncWrite` implementations of `Negotiated`.
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum NegotiationError {
    /// A protocol error occurred during the negotiation.
    ///
    /// I/O errors are reported through this variant as well, wrapped in a
    /// [`ProtocolError`].
    #[error("A protocol error occurred during the negotiation: `{0:?}`")]
    ProtocolError(#[from] ProtocolError),

    /// Protocol negotiation failed because no protocol could be agreed upon.
    #[error("Protocol negotiation failed.")]
    Failed,
}
363
364impl From<io::Error> for NegotiationError {
365    fn from(err: io::Error) -> NegotiationError {
366        ProtocolError::from(err).into()
367    }
368}
369
370impl From<NegotiationError> for io::Error {
371    fn from(err: NegotiationError) -> io::Error {
372        if let NegotiationError::ProtocolError(e) = err {
373            return e.into();
374        }
375        io::Error::new(io::ErrorKind::Other, err)
376    }
377}