libp2p_core/upgrade/transfer.rs
1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! Contains some helper futures for creating upgrades.
22
23use futures::prelude::*;
24use std::io;
25
// TODO: these functions could be methods on an extension trait for `AsyncRead`/`AsyncWrite`.
27
28/// Writes a message to the given socket with a length prefix appended to it. Also flushes the socket.
29///
30/// > **Note**: Prepends a variable-length prefix indicate the length of the message. This is
31/// > compatible with what [`read_length_prefixed`] expects.
32pub async fn write_length_prefixed(
33 socket: &mut (impl AsyncWrite + Unpin),
34 data: impl AsRef<[u8]>,
35) -> Result<(), io::Error> {
36 write_varint(socket, data.as_ref().len()).await?;
37 socket.write_all(data.as_ref()).await?;
38 socket.flush().await?;
39
40 Ok(())
41}
42
43/// Writes a variable-length integer to the `socket`.
44///
45/// > **Note**: Does **NOT** flush the socket.
46pub async fn write_varint(
47 socket: &mut (impl AsyncWrite + Unpin),
48 len: usize,
49) -> Result<(), io::Error> {
50 let mut len_data = unsigned_varint::encode::usize_buffer();
51 let encoded_len = unsigned_varint::encode::usize(len, &mut len_data).len();
52 socket.write_all(&len_data[..encoded_len]).await?;
53
54 Ok(())
55}
56
57/// Reads a variable-length integer from the `socket`.
58///
59/// As a special exception, if the `socket` is empty and EOFs right at the beginning, then we
60/// return `Ok(0)`.
61///
62/// > **Note**: This function reads bytes one by one from the `socket`. It is therefore encouraged
63/// > to use some sort of buffering mechanism.
64pub async fn read_varint(socket: &mut (impl AsyncRead + Unpin)) -> Result<usize, io::Error> {
65 let mut buffer = unsigned_varint::encode::usize_buffer();
66 let mut buffer_len = 0;
67
68 loop {
69 match socket.read(&mut buffer[buffer_len..buffer_len + 1]).await? {
70 0 => {
71 // Reaching EOF before finishing to read the length is an error, unless the EOF is
72 // at the very beginning of the substream, in which case we assume that the data is
73 // empty.
74 if buffer_len == 0 {
75 return Ok(0);
76 } else {
77 return Err(io::ErrorKind::UnexpectedEof.into());
78 }
79 }
80 n => debug_assert_eq!(n, 1),
81 }
82
83 buffer_len += 1;
84
85 match unsigned_varint::decode::usize(&buffer[..buffer_len]) {
86 Ok((len, _)) => return Ok(len),
87 Err(unsigned_varint::decode::Error::Overflow) => {
88 return Err(io::Error::new(
89 io::ErrorKind::InvalidData,
90 "overflow in variable-length integer",
91 ));
92 }
93 // TODO: why do we have a `__Nonexhaustive` variant in the error? I don't know how to process it
94 // Err(unsigned_varint::decode::Error::Insufficient) => {}
95 Err(_) => {}
96 }
97 }
98}
99
100/// Reads a length-prefixed message from the given socket.
101///
102/// The `max_size` parameter is the maximum size in bytes of the message that we accept. This is
103/// necessary in order to avoid DoS attacks where the remote sends us a message of several
104/// gigabytes.
105///
106/// > **Note**: Assumes that a variable-length prefix indicates the length of the message. This is
107/// > compatible with what [`write_length_prefixed`] does.
108pub async fn read_length_prefixed(
109 socket: &mut (impl AsyncRead + Unpin),
110 max_size: usize,
111) -> io::Result<Vec<u8>> {
112 let len = read_varint(socket).await?;
113 if len > max_size {
114 return Err(io::Error::new(
115 io::ErrorKind::InvalidData,
116 format!("Received data size ({len} bytes) exceeds maximum ({max_size} bytes)"),
117 ));
118 }
119
120 let mut buf = vec![0; len];
121 socket.read_exact(&mut buf).await?;
122
123 Ok(buf)
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `read_length_prefixed` to completion over an in-memory reader containing `input`.
    fn read_from(input: Vec<u8>, max_size: usize) -> io::Result<Vec<u8>> {
        futures::executor::block_on(async move {
            let mut socket = futures::io::Cursor::new(input);
            read_length_prefixed(&mut socket, max_size).await
        })
    }

    /// Returns the unsigned-varint encoding of `len`.
    fn encode_len(len: usize) -> Vec<u8> {
        let mut buf = unsigned_varint::encode::usize_buffer();
        unsigned_varint::encode::usize(len, &mut buf).to_vec()
    }

    #[test]
    fn write_length_prefixed_works() {
        let data = (0..rand::random::<usize>() % 10_000)
            .map(|_| rand::random::<u8>())
            .collect::<Vec<_>>();
        let mut out = vec![0; 10_000];

        futures::executor::block_on(async {
            let mut socket = futures::io::Cursor::new(&mut out[..]);

            write_length_prefixed(&mut socket, &data).await.unwrap();
            socket.close().await.unwrap();
        });

        let (out_len, out_data) = unsigned_varint::decode::usize(&out).unwrap();
        assert_eq!(out_len, data.len());
        assert_eq!(&out_data[..out_len], &data[..]);
    }

    #[test]
    fn read_length_prefixed_works() {
        let original_data = (0..rand::random::<usize>() % 10_000)
            .map(|_| rand::random::<u8>())
            .collect::<Vec<_>>();

        let mut input = encode_len(original_data.len());
        input.extend_from_slice(&original_data);

        assert_eq!(read_from(input, 10_000).unwrap(), original_data);
    }

    #[test]
    fn read_length_prefixed_zero_len() {
        assert!(read_from(vec![0], 10_000).unwrap().is_empty());
    }

    #[test]
    fn read_length_prefixed_checks_length() {
        let mut input = encode_len(5_000);
        input.extend_from_slice(&[0; 5_000]);

        let err = read_from(input, 100).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn read_length_prefixed_accepts_empty() {
        assert!(read_from(Vec::new(), 10_000).unwrap().is_empty());
    }

    #[test]
    fn read_length_prefixed_eof_before_len() {
        let err = read_from(vec![0x80], 10_000).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn read_length_prefixed_eof_in_payload() {
        let mut input = encode_len(10);
        input.extend_from_slice(&[1, 2, 3]);

        let err = read_from(input, 10_000).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn read_varint_rejects_overflow() {
        // Every byte has its continuation bit set, so the integer never terminates.
        let err = futures::executor::block_on(async {
            let mut socket = futures::io::Cursor::new(vec![0xFF; 16]);
            read_varint(&mut socket).await
        })
        .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}