// libp2p_noise/io/framed.rs

// Copyright 2020 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
20
//! This module provides a `Sink` and `Stream` for length-delimited
//! Noise protocol messages in the form of [`NoiseFramed`].
23
24use crate::io::Output;
25use crate::{protocol::PublicKey, Error};
26use bytes::{Bytes, BytesMut};
27use futures::prelude::*;
28use futures::ready;
29use log::{debug, trace};
30use std::{
31    fmt, io,
32    pin::Pin,
33    task::{Context, Poll},
34};
35
/// Largest Noise message that can be framed. Frames carry a 2-byte
/// big-endian length prefix, so nothing larger than `u16::MAX` (65535)
/// bytes can be expressed.
const MAX_NOISE_MSG_LEN: usize = u16::MAX as usize;
/// Headroom added to the encryption buffer, beyond the plaintext length,
/// to hold the key material written alongside the payload.
const EXTRA_ENCRYPT_SPACE: usize = 1024;
/// Largest plaintext payload accepted per frame: whatever remains of
/// [`MAX_NOISE_MSG_LEN`] once [`EXTRA_ENCRYPT_SPACE`] is set aside.
pub(crate) const MAX_FRAME_LEN: usize = MAX_NOISE_MSG_LEN - EXTRA_ENCRYPT_SPACE;
42static_assertions::const_assert! {
43    MAX_FRAME_LEN + EXTRA_ENCRYPT_SPACE <= MAX_NOISE_MSG_LEN
44}
45
46/// A `NoiseFramed` is a `Sink` and `Stream` for length-delimited
47/// Noise protocol messages.
48///
49/// `T` is the type of the underlying I/O resource and `S` the
50/// type of the Noise session state.
51pub(crate) struct NoiseFramed<T, S> {
52    io: T,
53    session: S,
54    read_state: ReadState,
55    write_state: WriteState,
56    read_buffer: Vec<u8>,
57    write_buffer: Vec<u8>,
58    decrypt_buffer: BytesMut,
59}
60
61impl<T, S> fmt::Debug for NoiseFramed<T, S> {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        f.debug_struct("NoiseFramed")
64            .field("read_state", &self.read_state)
65            .field("write_state", &self.write_state)
66            .finish()
67    }
68}
69
70impl<T> NoiseFramed<T, snow::HandshakeState> {
71    /// Creates a nwe `NoiseFramed` for beginning a Noise protocol handshake.
72    pub(crate) fn new(io: T, state: snow::HandshakeState) -> Self {
73        NoiseFramed {
74            io,
75            session: state,
76            read_state: ReadState::Ready,
77            write_state: WriteState::Ready,
78            read_buffer: Vec::new(),
79            write_buffer: Vec::new(),
80            decrypt_buffer: BytesMut::new(),
81        }
82    }
83
84    pub(crate) fn is_initiator(&self) -> bool {
85        self.session.is_initiator()
86    }
87
88    pub(crate) fn is_responder(&self) -> bool {
89        !self.session.is_initiator()
90    }
91
92    /// Converts the `NoiseFramed` into a `NoiseOutput` encrypted data stream
93    /// once the handshake is complete, including the static DH [`PublicKey`]
94    /// of the remote, if received.
95    ///
96    /// If the underlying Noise protocol session state does not permit
97    /// transitioning to transport mode because the handshake is incomplete,
98    /// an error is returned. Similarly if the remote's static DH key, if
99    /// present, cannot be parsed.
100    pub(crate) fn into_transport(self) -> Result<(PublicKey, Output<T>), Error> {
101        let dh_remote_pubkey = self.session.get_remote_static().ok_or_else(|| {
102            Error::Io(io::Error::new(
103                io::ErrorKind::Other,
104                "expect key to always be present at end of XX session",
105            ))
106        })?;
107
108        let dh_remote_pubkey = PublicKey::from_slice(dh_remote_pubkey)?;
109
110        let io = NoiseFramed {
111            session: self.session.into_transport_mode()?,
112            io: self.io,
113            read_state: ReadState::Ready,
114            write_state: WriteState::Ready,
115            read_buffer: self.read_buffer,
116            write_buffer: self.write_buffer,
117            decrypt_buffer: self.decrypt_buffer,
118        };
119
120        Ok((dh_remote_pubkey, Output::new(io)))
121    }
122}
123
/// Progress of reading a single length-prefixed Noise frame.
#[derive(Debug)]
enum ReadState {
    /// Idle: the next poll starts on a fresh frame.
    Ready,
    /// Collecting the 2-byte big-endian length prefix. `off` bytes of `buf`
    /// have been filled so far.
    ReadLen {
        buf: [u8; 2],
        off: usize,
    },
    /// Collecting the `len`-byte frame body. `off` bytes have been read so far.
    ReadData {
        len: usize,
        off: usize,
    },
    /// The I/O resource reached EOF. This state is terminal.
    ///
    /// `Ok(())` is an expected EOF, reported as end of stream. `Err(())` is an
    /// unexpected one, reported as an `UnexpectedEof` error.
    Eof(Result<(), ()>),
    /// Decrypting a frame failed. This state is terminal.
    DecErr,
}
140
/// Progress of writing a single length-prefixed Noise frame.
#[derive(Debug)]
enum WriteState {
    /// Idle: a new frame may be submitted.
    Ready,
    /// Sending `buf`, the 2-byte length prefix of a `len`-byte frame. `off`
    /// bytes of the prefix have been written so far.
    WriteLen { len: usize, buf: [u8; 2], off: usize },
    /// Sending the `len`-byte frame body. `off` bytes have been written so far.
    WriteData { len: usize, off: usize },
    /// The I/O resource accepted zero bytes, which is treated as an
    /// unexpected EOF. This state is terminal.
    Eof,
    /// Encrypting a frame failed. This state is terminal.
    EncErr,
}
159
160impl WriteState {
161    fn is_ready(&self) -> bool {
162        if let WriteState::Ready = self {
163            return true;
164        }
165        false
166    }
167}
168
169impl<T, S> futures::stream::Stream for NoiseFramed<T, S>
170where
171    T: AsyncRead + Unpin,
172    S: SessionState + Unpin,
173{
174    type Item = io::Result<Bytes>;
175
176    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
177        let this = Pin::into_inner(self);
178        loop {
179            trace!("read state: {:?}", this.read_state);
180            match this.read_state {
181                ReadState::Ready => {
182                    this.read_state = ReadState::ReadLen {
183                        buf: [0, 0],
184                        off: 0,
185                    };
186                }
187                ReadState::ReadLen { mut buf, mut off } => {
188                    let n = match read_frame_len(&mut this.io, cx, &mut buf, &mut off) {
189                        Poll::Ready(Ok(Some(n))) => n,
190                        Poll::Ready(Ok(None)) => {
191                            trace!("read: eof");
192                            this.read_state = ReadState::Eof(Ok(()));
193                            return Poll::Ready(None);
194                        }
195                        Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))),
196                        Poll::Pending => {
197                            this.read_state = ReadState::ReadLen { buf, off };
198                            return Poll::Pending;
199                        }
200                    };
201                    trace!("read: frame len = {}", n);
202                    if n == 0 {
203                        trace!("read: empty frame");
204                        this.read_state = ReadState::Ready;
205                        continue;
206                    }
207                    this.read_buffer.resize(usize::from(n), 0u8);
208                    this.read_state = ReadState::ReadData {
209                        len: usize::from(n),
210                        off: 0,
211                    }
212                }
213                ReadState::ReadData { len, ref mut off } => {
214                    let n = {
215                        let f =
216                            Pin::new(&mut this.io).poll_read(cx, &mut this.read_buffer[*off..len]);
217                        match ready!(f) {
218                            Ok(n) => n,
219                            Err(e) => return Poll::Ready(Some(Err(e))),
220                        }
221                    };
222                    trace!("read: {}/{} bytes", *off + n, len);
223                    if n == 0 {
224                        trace!("read: eof");
225                        this.read_state = ReadState::Eof(Err(()));
226                        return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())));
227                    }
228                    *off += n;
229                    if len == *off {
230                        trace!("read: decrypting {} bytes", len);
231                        this.decrypt_buffer.resize(len, 0);
232                        if let Ok(n) = this
233                            .session
234                            .read_message(&this.read_buffer, &mut this.decrypt_buffer)
235                        {
236                            this.decrypt_buffer.truncate(n);
237                            trace!("read: payload len = {} bytes", n);
238                            this.read_state = ReadState::Ready;
239                            // Return an immutable view into the current buffer.
240                            // If the view is dropped before the next frame is
241                            // read, the `BytesMut` will reuse the same buffer
242                            // for the next frame.
243                            let view = this.decrypt_buffer.split().freeze();
244                            return Poll::Ready(Some(Ok(view)));
245                        } else {
246                            debug!("read: decryption error");
247                            this.read_state = ReadState::DecErr;
248                            return Poll::Ready(Some(Err(io::ErrorKind::InvalidData.into())));
249                        }
250                    }
251                }
252                ReadState::Eof(Ok(())) => {
253                    trace!("read: eof");
254                    return Poll::Ready(None);
255                }
256                ReadState::Eof(Err(())) => {
257                    trace!("read: eof (unexpected)");
258                    return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())));
259                }
260                ReadState::DecErr => {
261                    return Poll::Ready(Some(Err(io::ErrorKind::InvalidData.into())))
262                }
263            }
264        }
265    }
266}
267
268impl<T, S> futures::sink::Sink<&Vec<u8>> for NoiseFramed<T, S>
269where
270    T: AsyncWrite + Unpin,
271    S: SessionState + Unpin,
272{
273    type Error = io::Error;
274
275    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
276        let this = Pin::into_inner(self);
277        loop {
278            trace!("write state {:?}", this.write_state);
279            match this.write_state {
280                WriteState::Ready => {
281                    return Poll::Ready(Ok(()));
282                }
283                WriteState::WriteLen { len, buf, mut off } => {
284                    trace!("write: frame len ({}, {:?}, {}/2)", len, buf, off);
285                    match write_frame_len(&mut this.io, cx, &buf, &mut off) {
286                        Poll::Ready(Ok(true)) => (),
287                        Poll::Ready(Ok(false)) => {
288                            trace!("write: eof");
289                            this.write_state = WriteState::Eof;
290                            return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
291                        }
292                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
293                        Poll::Pending => {
294                            this.write_state = WriteState::WriteLen { len, buf, off };
295                            return Poll::Pending;
296                        }
297                    }
298                    this.write_state = WriteState::WriteData { len, off: 0 }
299                }
300                WriteState::WriteData { len, ref mut off } => {
301                    let n = {
302                        let f =
303                            Pin::new(&mut this.io).poll_write(cx, &this.write_buffer[*off..len]);
304                        match ready!(f) {
305                            Ok(n) => n,
306                            Err(e) => return Poll::Ready(Err(e)),
307                        }
308                    };
309                    if n == 0 {
310                        trace!("write: eof");
311                        this.write_state = WriteState::Eof;
312                        return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
313                    }
314                    *off += n;
315                    trace!("write: {}/{} bytes written", *off, len);
316                    if len == *off {
317                        trace!("write: finished with {} bytes", len);
318                        this.write_state = WriteState::Ready;
319                    }
320                }
321                WriteState::Eof => {
322                    trace!("write: eof");
323                    return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
324                }
325                WriteState::EncErr => return Poll::Ready(Err(io::ErrorKind::InvalidData.into())),
326            }
327        }
328    }
329
330    fn start_send(self: Pin<&mut Self>, frame: &Vec<u8>) -> Result<(), Self::Error> {
331        assert!(frame.len() <= MAX_FRAME_LEN);
332        let this = Pin::into_inner(self);
333        assert!(this.write_state.is_ready());
334
335        this.write_buffer
336            .resize(frame.len() + EXTRA_ENCRYPT_SPACE, 0u8);
337        match this
338            .session
339            .write_message(frame, &mut this.write_buffer[..])
340        {
341            Ok(n) => {
342                trace!("write: cipher text len = {} bytes", n);
343                this.write_buffer.truncate(n);
344                this.write_state = WriteState::WriteLen {
345                    len: n,
346                    buf: u16::to_be_bytes(n as u16),
347                    off: 0,
348                };
349                Ok(())
350            }
351            Err(e) => {
352                log::error!("encryption error: {:?}", e);
353                this.write_state = WriteState::EncErr;
354                Err(io::ErrorKind::InvalidData.into())
355            }
356        }
357    }
358
359    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
360        ready!(self.as_mut().poll_ready(cx))?;
361        Pin::new(&mut self.io).poll_flush(cx)
362    }
363
364    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
365        ready!(self.as_mut().poll_flush(cx))?;
366        Pin::new(&mut self.io).poll_close(cx)
367    }
368}
369
370/// A stateful context in which Noise protocol messages can be read and written.
371pub(crate) trait SessionState {
372    fn read_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result<usize, snow::Error>;
373    fn write_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result<usize, snow::Error>;
374}
375
376impl SessionState for snow::HandshakeState {
377    fn read_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result<usize, snow::Error> {
378        self.read_message(msg, buf)
379    }
380
381    fn write_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result<usize, snow::Error> {
382        self.write_message(msg, buf)
383    }
384}
385
386impl SessionState for snow::TransportState {
387    fn read_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result<usize, snow::Error> {
388        self.read_message(msg, buf)
389    }
390
391    fn write_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result<usize, snow::Error> {
392        self.write_message(msg, buf)
393    }
394}
395
396/// Read 2 bytes as frame length from the given source into the given buffer.
397///
398/// Panics if `off >= 2`.
399///
400/// When [`Poll::Pending`] is returned, the given buffer and offset
401/// may have been updated (i.e. a byte may have been read) and must be preserved
402/// for the next invocation.
403///
404/// Returns `None` if EOF has been encountered.
405fn read_frame_len<R: AsyncRead + Unpin>(
406    mut io: &mut R,
407    cx: &mut Context<'_>,
408    buf: &mut [u8; 2],
409    off: &mut usize,
410) -> Poll<io::Result<Option<u16>>> {
411    loop {
412        match ready!(Pin::new(&mut io).poll_read(cx, &mut buf[*off..])) {
413            Ok(n) => {
414                if n == 0 {
415                    return Poll::Ready(Ok(None));
416                }
417                *off += n;
418                if *off == 2 {
419                    return Poll::Ready(Ok(Some(u16::from_be_bytes(*buf))));
420                }
421            }
422            Err(e) => {
423                return Poll::Ready(Err(e));
424            }
425        }
426    }
427}
428
429/// Write 2 bytes as frame length from the given buffer into the given sink.
430///
431/// Panics if `off >= 2`.
432///
433/// When [`Poll::Pending`] is returned, the given offset
434/// may have been updated (i.e. a byte may have been written) and must
435/// be preserved for the next invocation.
436///
437/// Returns `false` if EOF has been encountered.
438fn write_frame_len<W: AsyncWrite + Unpin>(
439    mut io: &mut W,
440    cx: &mut Context<'_>,
441    buf: &[u8; 2],
442    off: &mut usize,
443) -> Poll<io::Result<bool>> {
444    loop {
445        match ready!(Pin::new(&mut io).poll_write(cx, &buf[*off..])) {
446            Ok(n) => {
447                if n == 0 {
448                    return Poll::Ready(Ok(false));
449                }
450                *off += n;
451                if *off == 2 {
452                    return Poll::Ready(Ok(true));
453                }
454            }
455            Err(e) => {
456                return Poll::Ready(Err(e));
457            }
458        }
459    }
460}