libp2p_core/upgrade/
transfer.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! Contains some helper futures for creating upgrades.
22
23use futures::prelude::*;
24use std::io;
25
// TODO: these functions could be provided as methods of extension traits on `AsyncWrite` and `AsyncRead`.
27
28/// Writes a message to the given socket with a length prefix appended to it. Also flushes the socket.
29///
30/// > **Note**: Prepends a variable-length prefix indicate the length of the message. This is
31/// >           compatible with what [`read_length_prefixed`] expects.
32pub async fn write_length_prefixed(
33    socket: &mut (impl AsyncWrite + Unpin),
34    data: impl AsRef<[u8]>,
35) -> Result<(), io::Error> {
36    write_varint(socket, data.as_ref().len()).await?;
37    socket.write_all(data.as_ref()).await?;
38    socket.flush().await?;
39
40    Ok(())
41}
42
43/// Writes a variable-length integer to the `socket`.
44///
45/// > **Note**: Does **NOT** flush the socket.
46pub async fn write_varint(
47    socket: &mut (impl AsyncWrite + Unpin),
48    len: usize,
49) -> Result<(), io::Error> {
50    let mut len_data = unsigned_varint::encode::usize_buffer();
51    let encoded_len = unsigned_varint::encode::usize(len, &mut len_data).len();
52    socket.write_all(&len_data[..encoded_len]).await?;
53
54    Ok(())
55}
56
57/// Reads a variable-length integer from the `socket`.
58///
59/// As a special exception, if the `socket` is empty and EOFs right at the beginning, then we
60/// return `Ok(0)`.
61///
62/// > **Note**: This function reads bytes one by one from the `socket`. It is therefore encouraged
63/// >           to use some sort of buffering mechanism.
64pub async fn read_varint(socket: &mut (impl AsyncRead + Unpin)) -> Result<usize, io::Error> {
65    let mut buffer = unsigned_varint::encode::usize_buffer();
66    let mut buffer_len = 0;
67
68    loop {
69        match socket.read(&mut buffer[buffer_len..buffer_len + 1]).await? {
70            0 => {
71                // Reaching EOF before finishing to read the length is an error, unless the EOF is
72                // at the very beginning of the substream, in which case we assume that the data is
73                // empty.
74                if buffer_len == 0 {
75                    return Ok(0);
76                } else {
77                    return Err(io::ErrorKind::UnexpectedEof.into());
78                }
79            }
80            n => debug_assert_eq!(n, 1),
81        }
82
83        buffer_len += 1;
84
85        match unsigned_varint::decode::usize(&buffer[..buffer_len]) {
86            Ok((len, _)) => return Ok(len),
87            Err(unsigned_varint::decode::Error::Overflow) => {
88                return Err(io::Error::new(
89                    io::ErrorKind::InvalidData,
90                    "overflow in variable-length integer",
91                ));
92            }
93            // TODO: why do we have a `__Nonexhaustive` variant in the error? I don't know how to process it
94            // Err(unsigned_varint::decode::Error::Insufficient) => {}
95            Err(_) => {}
96        }
97    }
98}
99
100/// Reads a length-prefixed message from the given socket.
101///
102/// The `max_size` parameter is the maximum size in bytes of the message that we accept. This is
103/// necessary in order to avoid DoS attacks where the remote sends us a message of several
104/// gigabytes.
105///
106/// > **Note**: Assumes that a variable-length prefix indicates the length of the message. This is
107/// >           compatible with what [`write_length_prefixed`] does.
108pub async fn read_length_prefixed(
109    socket: &mut (impl AsyncRead + Unpin),
110    max_size: usize,
111) -> io::Result<Vec<u8>> {
112    let len = read_varint(socket).await?;
113    if len > max_size {
114        return Err(io::Error::new(
115            io::ErrorKind::InvalidData,
116            format!("Received data size ({len} bytes) exceeds maximum ({max_size} bytes)"),
117        ));
118    }
119
120    let mut buf = vec![0; len];
121    socket.read_exact(&mut buf).await?;
122
123    Ok(buf)
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame made of `len` encoded as a varint, followed by `payload`.
    fn framed(len: usize, payload: &[u8]) -> Vec<u8> {
        let mut len_buf = unsigned_varint::encode::usize_buffer();
        let mut out = unsigned_varint::encode::usize(len, &mut len_buf).to_vec();
        out.extend_from_slice(payload);
        out
    }

    /// Runs `read_length_prefixed` against an in-memory socket containing `input`.
    fn read_from(input: Vec<u8>, max_size: usize) -> io::Result<Vec<u8>> {
        futures::executor::block_on(async {
            let mut socket = futures::io::Cursor::new(input);
            read_length_prefixed(&mut socket, max_size).await
        })
    }

    #[test]
    fn write_length_prefixed_works() {
        let data = (0..rand::random::<usize>() % 10_000)
            .map(|_| rand::random::<u8>())
            .collect::<Vec<_>>();
        let mut out = vec![0; 10_000];

        futures::executor::block_on(async {
            let mut socket = futures::io::Cursor::new(&mut out[..]);

            write_length_prefixed(&mut socket, &data).await.unwrap();
            socket.close().await.unwrap();
        });

        let (out_len, out_data) = unsigned_varint::decode::usize(&out).unwrap();
        assert_eq!(out_len, data.len());
        assert_eq!(&out_data[..out_len], &data[..]);
    }

    #[test]
    fn read_length_prefixed_works() {
        let original_data = (0..rand::random::<usize>() % 10_000)
            .map(|_| rand::random::<u8>())
            .collect::<Vec<_>>();

        let out = read_from(framed(original_data.len(), &original_data), 10_000).unwrap();
        assert_eq!(out, original_data);
    }

    #[test]
    fn read_length_prefixed_zero_len() {
        let out = read_from(vec![0], 10_000).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn read_length_prefixed_checks_length() {
        let err = read_from(framed(5_000, &[0; 5_000]), 100).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn read_length_prefixed_accepts_empty() {
        let out = read_from(Vec::new(), 10_000).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn read_length_prefixed_eof_before_len() {
        // 0x80 has the continuation bit set, so the length prefix is incomplete.
        let err = read_from(vec![0x80], 10_000).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn write_then_read_round_trips() {
        let data = b"hello, libp2p".to_vec();

        let written = futures::executor::block_on(async {
            let mut socket = futures::io::Cursor::new(Vec::new());
            write_length_prefixed(&mut socket, &data).await.unwrap();
            socket.into_inner()
        });

        assert_eq!(read_from(written, 10_000).unwrap(), data);
    }
}