litep2p/multistream_select/
protocol.rs1use crate::{
29 codec::unsigned_varint::UnsignedVarint,
30 error::Error as Litep2pError,
31 multistream_select::{
32 length_delimited::{LengthDelimited, LengthDelimitedReader},
33 Version,
34 },
35};
36
37use bytes::{BufMut, Bytes, BytesMut};
38use futures::{io::IoSlice, prelude::*, ready};
39use std::{
40 convert::TryFrom,
41 error::Error,
42 fmt, io,
43 pin::Pin,
44 task::{Context, Poll},
45};
46use unsigned_varint as uvi;
47
/// Upper bound on the number of protocol names accepted from a single `ls` response.
const MAX_PROTOCOLS: usize = 1_000;

/// Wire form of the multistream-select 1.0.0 header line (newline-terminated).
pub const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n";
/// Wire form of the "protocol not available" reply.
const MSG_PROTOCOL_NA: &[u8] = b"na\n";
/// Wire form of the "list protocols" request.
const MSG_LS: &[u8] = b"ls\n";
/// `tracing` target for log output from this module.
const LOG_TARGET: &str = "litep2p::multistream-select";
59
/// Header lines that can open a multistream-select negotiation.
///
/// Every negotiation [`Version`] maps onto one of these variants.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HeaderLine {
    /// The `/multistream/1.0.0` header line.
    V1,
}
68
69impl From<Version> for HeaderLine {
70 fn from(v: Version) -> HeaderLine {
71 match v {
72 Version::V1 | Version::V1Lazy => HeaderLine::V1,
73 }
74 }
75}
76
77#[derive(Clone, Debug, PartialEq, Eq)]
79pub struct Protocol(Bytes);
80
81impl AsRef<[u8]> for Protocol {
82 fn as_ref(&self) -> &[u8] {
83 self.0.as_ref()
84 }
85}
86
87impl TryFrom<Bytes> for Protocol {
88 type Error = ProtocolError;
89
90 fn try_from(value: Bytes) -> Result<Self, Self::Error> {
91 if !value.as_ref().starts_with(b"/") {
92 return Err(ProtocolError::InvalidProtocol);
93 }
94 Ok(Protocol(value))
95 }
96}
97
98impl TryFrom<&[u8]> for Protocol {
99 type Error = ProtocolError;
100
101 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
102 Self::try_from(Bytes::copy_from_slice(value))
103 }
104}
105
106impl fmt::Display for Protocol {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 write!(f, "{}", String::from_utf8_lossy(&self.0))
109 }
110}
111
112#[derive(Debug, Clone, PartialEq, Eq)]
117pub enum Message {
118 Header(HeaderLine),
121 Protocol(Protocol),
123 ListProtocols,
126 Protocols(Vec<Protocol>),
128 NotAvailable,
130}
131
132impl Message {
133 pub fn encode(&self, dest: &mut BytesMut) -> Result<(), ProtocolError> {
135 match self {
136 Message::Header(HeaderLine::V1) => {
137 dest.reserve(MSG_MULTISTREAM_1_0.len());
138 dest.put(MSG_MULTISTREAM_1_0);
139 Ok(())
140 }
141 Message::Protocol(p) => {
142 let len = p.0.as_ref().len() + 1; dest.reserve(len);
144 dest.put(p.0.as_ref());
145 dest.put_u8(b'\n');
146 Ok(())
147 }
148 Message::ListProtocols => {
149 dest.reserve(MSG_LS.len());
150 dest.put(MSG_LS);
151 Ok(())
152 }
153 Message::Protocols(ps) => {
154 let mut buf = uvi::encode::usize_buffer();
155 let mut encoded = Vec::with_capacity(ps.len());
156 for p in ps {
157 encoded.extend(uvi::encode::usize(p.0.as_ref().len() + 1, &mut buf)); encoded.extend_from_slice(p.0.as_ref());
159 encoded.push(b'\n')
160 }
161 encoded.push(b'\n');
162 dest.reserve(encoded.len());
163 dest.put(encoded.as_ref());
164 Ok(())
165 }
166 Message::NotAvailable => {
167 dest.reserve(MSG_PROTOCOL_NA.len());
168 dest.put(MSG_PROTOCOL_NA);
169 Ok(())
170 }
171 }
172 }
173
174 pub fn decode(mut msg: Bytes) -> Result<Message, ProtocolError> {
176 if msg == MSG_MULTISTREAM_1_0 {
177 return Ok(Message::Header(HeaderLine::V1));
178 }
179
180 if msg == MSG_PROTOCOL_NA {
181 return Ok(Message::NotAvailable);
182 }
183
184 if msg == MSG_LS {
185 return Ok(Message::ListProtocols);
186 }
187
188 if msg.first() == Some(&b'/')
191 && msg.last() == Some(&b'\n')
192 && !msg[..msg.len() - 1].contains(&b'\n')
193 {
194 let p = Protocol::try_from(msg.split_to(msg.len() - 1))?;
195 return Ok(Message::Protocol(p));
196 }
197
198 let mut protocols = Vec::new();
201 let mut remaining: &[u8] = &msg;
202 loop {
203 if remaining == [b'\n'] || remaining.is_empty() {
206 break;
207 } else if protocols.len() == MAX_PROTOCOLS {
208 return Err(ProtocolError::TooManyProtocols);
209 }
210
211 let (len, tail) = uvi::decode::usize(remaining)?;
214 if len == 0 || len > tail.len() || tail[len - 1] != b'\n' {
215 return Err(ProtocolError::InvalidMessage);
216 }
217
218 let p = Protocol::try_from(Bytes::copy_from_slice(&tail[..len - 1]))?;
220 protocols.push(p);
221
222 remaining = &tail[len..];
224 }
225
226 Ok(Message::Protocols(protocols))
227 }
228}
229
230pub fn encode_multistream_message(
232 messages: impl IntoIterator<Item = Message>,
233) -> crate::Result<BytesMut> {
234 let mut bytes = BytesMut::with_capacity(32);
236 let message = Message::Header(HeaderLine::V1);
237 message.encode(&mut bytes).map_err(|_| Litep2pError::InvalidData)?;
238 let mut header = UnsignedVarint::encode(bytes)?;
239
240 for message in messages {
242 let mut proto_bytes = BytesMut::with_capacity(256);
243 message.encode(&mut proto_bytes).map_err(|_| Litep2pError::InvalidData)?;
244 let mut proto_bytes = UnsignedVarint::encode(proto_bytes)?;
245 header.append(&mut proto_bytes);
246 }
247
248 Ok(BytesMut::from(&header[..]))
249}
250
#[pin_project::pin_project]
/// Sends and receives [`Message`]s over an I/O resource `R`.
///
/// The resource is wrapped in [`LengthDelimited`]. Outgoing messages go through the
/// `Sink<Message>` impl, and incoming ones are yielded by the `Stream` impl.
pub struct MessageIO<R> {
    /// Framing layer around the underlying resource.
    #[pin]
    inner: LengthDelimited<R>,
}
257
258impl<R> MessageIO<R> {
259 pub fn new(inner: R) -> MessageIO<R>
261 where
262 R: AsyncRead + AsyncWrite,
263 {
264 Self {
265 inner: LengthDelimited::new(inner),
266 }
267 }
268
269 pub fn into_reader(self) -> MessageReader<R> {
277 MessageReader {
278 inner: self.inner.into_reader(),
279 }
280 }
281
282 pub fn into_inner(self) -> R {
292 self.inner.into_inner()
293 }
294}
295
296impl<R> Sink<Message> for MessageIO<R>
297where
298 R: AsyncWrite,
299{
300 type Error = ProtocolError;
301
302 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
303 self.project().inner.poll_ready(cx).map_err(From::from)
304 }
305
306 fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
307 let mut buf = BytesMut::new();
308 item.encode(&mut buf)?;
309 self.project().inner.start_send(buf.freeze()).map_err(From::from)
310 }
311
312 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
313 self.project().inner.poll_flush(cx).map_err(From::from)
314 }
315
316 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
317 self.project().inner.poll_close(cx).map_err(From::from)
318 }
319}
320
321impl<R> Stream for MessageIO<R>
322where
323 R: AsyncRead,
324{
325 type Item = Result<Message, ProtocolError>;
326
327 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
328 match poll_stream(self.project().inner, cx) {
329 Poll::Pending => Poll::Pending,
330 Poll::Ready(None) => Poll::Ready(None),
331 Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(Ok(m))),
332 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
333 }
334 }
335}
336
#[pin_project::pin_project]
/// Read side of a negotiation, obtained from [`MessageIO::into_reader`].
///
/// Its `Stream` impl yields decoded [`Message`]s. Its `AsyncWrite` impl delegates every
/// call to the inner [`LengthDelimitedReader`].
#[derive(Debug)]
pub struct MessageReader<R> {
    /// Read-side framing layer around the underlying resource.
    #[pin]
    inner: LengthDelimitedReader<R>,
}
345
346impl<R> MessageReader<R> {
347 pub fn into_inner(self) -> R {
359 self.inner.into_inner()
360 }
361}
362
363impl<R> Stream for MessageReader<R>
364where
365 R: AsyncRead,
366{
367 type Item = Result<Message, ProtocolError>;
368
369 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
370 poll_stream(self.project().inner, cx)
371 }
372}
373
374impl<TInner> AsyncWrite for MessageReader<TInner>
375where
376 TInner: AsyncWrite,
377{
378 fn poll_write(
379 self: Pin<&mut Self>,
380 cx: &mut Context<'_>,
381 buf: &[u8],
382 ) -> Poll<Result<usize, io::Error>> {
383 self.project().inner.poll_write(cx, buf)
384 }
385
386 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
387 self.project().inner.poll_flush(cx)
388 }
389
390 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
391 self.project().inner.poll_close(cx)
392 }
393
394 fn poll_write_vectored(
395 self: Pin<&mut Self>,
396 cx: &mut Context<'_>,
397 bufs: &[IoSlice<'_>],
398 ) -> Poll<Result<usize, io::Error>> {
399 self.project().inner.poll_write_vectored(cx, bufs)
400 }
401}
402
403fn poll_stream<S>(
404 stream: Pin<&mut S>,
405 cx: &mut Context<'_>,
406) -> Poll<Option<Result<Message, ProtocolError>>>
407where
408 S: Stream<Item = Result<Bytes, io::Error>>,
409{
410 let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) {
411 match Message::decode(msg) {
412 Ok(m) => m,
413 Err(err) => return Poll::Ready(Some(Err(err))),
414 }
415 } else {
416 return Poll::Ready(None);
417 };
418
419 tracing::trace!(target: LOG_TARGET, "Received message: {:?}", msg);
420
421 Poll::Ready(Some(Ok(msg)))
422}
423
/// Errors produced while encoding, decoding or exchanging multistream-select messages.
#[derive(Debug, thiserror::Error)]
pub enum ProtocolError {
    /// I/O failure from the underlying stream. Also used for malformed unsigned-varint
    /// length prefixes, which are mapped to an `InvalidData` I/O error.
    #[error("I/O error: `{0}`")]
    IoError(#[from] io::Error),

    /// A received message could not be parsed, e.g. a list entry with a zero or
    /// out-of-range length, or one not terminated by `\n`.
    #[error("Received an invalid message from the remote.")]
    InvalidMessage,

    /// A protocol name failed validation in `Protocol::try_from`.
    #[error("A protocol (name) is invalid.")]
    InvalidProtocol,

    /// A protocol list contained more than `MAX_PROTOCOLS` names.
    #[error("Too many protocols have been returned by the remote.")]
    TooManyProtocols,

    /// Not constructed in this section. Presumably raised by negotiation code elsewhere
    /// when no common protocol is found; confirm.
    #[error("The protocol is not supported.")]
    ProtocolNotSupported,
}
447
448impl PartialEq for ProtocolError {
449 fn eq(&self, other: &Self) -> bool {
450 match (self, other) {
451 (ProtocolError::IoError(lhs), ProtocolError::IoError(rhs)) => lhs.kind() == rhs.kind(),
452 _ => std::mem::discriminant(self) == std::mem::discriminant(other),
453 }
454 }
455}
456
457impl From<ProtocolError> for io::Error {
458 fn from(err: ProtocolError) -> Self {
459 if let ProtocolError::IoError(e) = err {
460 return e;
461 }
462 io::ErrorKind::InvalidData.into()
463 }
464}
465
466impl From<uvi::decode::Error> for ProtocolError {
467 fn from(err: uvi::decode::Error) -> ProtocolError {
468 Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
469 }
470}