litep2p/transport/websocket/
stream.rs

1// Copyright 2023 litep2p developers
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
//! Adapter that exposes a `tokio_tungstenite::WebSocketStream` as a byte stream implementing
//! `futures::AsyncRead + futures::AsyncWrite`.
23
24use bytes::{Buf, Bytes};
25use futures::{SinkExt, StreamExt};
26use tokio::io::{AsyncRead, AsyncWrite};
27use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
28
29use std::{
30    pin::Pin,
31    task::{Context, Poll},
32};
33
34// TODO: add tests
35
/// State of the send path driven by `poll_flush`.
enum State {
    /// A sink operation failed or a transition was interrupted; further flushes report an
    /// error.
    Poisoned,

    /// No message is in flight; buffered bytes can be packaged into a new message.
    ReadyToSend,

    /// A message has been prepared and is waiting for the sink to accept it.
    ReadyPending {
        /// Payload of the binary message to send.
        to_write: Vec<u8>,
    },

    /// The message has been handed to the sink and the sink is being flushed.
    FlushPending,
}
50
51/// Buffered stream which implements `AsyncRead + AsyncWrite`
52pub(super) struct BufferedStream<S: AsyncRead + AsyncWrite + Unpin> {
53    /// Write buffer.
54    write_buffer: Vec<u8>,
55
56    /// Write pointer.
57    write_ptr: usize,
58
59    // Read buffer.
60    read_buffer: Option<Bytes>,
61
62    /// Underlying WebSocket stream.
63    stream: WebSocketStream<S>,
64
65    /// Read state.
66    state: State,
67}
68
69impl<S: AsyncRead + AsyncWrite + Unpin> BufferedStream<S> {
70    /// Create new [`BufferedStream`].
71    pub(super) fn new(stream: WebSocketStream<S>) -> Self {
72        Self {
73            write_buffer: Vec::with_capacity(2000),
74            read_buffer: None,
75            write_ptr: 0usize,
76            stream,
77            state: State::ReadyToSend,
78        }
79    }
80}
81
82impl<S: AsyncRead + AsyncWrite + Unpin> futures::AsyncWrite for BufferedStream<S> {
83    fn poll_write(
84        mut self: Pin<&mut Self>,
85        _cx: &mut Context<'_>,
86        buf: &[u8],
87    ) -> Poll<std::io::Result<usize>> {
88        self.write_buffer.extend_from_slice(buf);
89        self.write_ptr += buf.len();
90
91        Poll::Ready(Ok(buf.len()))
92    }
93
94    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
95        if self.write_buffer.is_empty() {
96            return self
97                .stream
98                .poll_ready_unpin(cx)
99                .map_err(|_| std::io::ErrorKind::UnexpectedEof.into());
100        }
101
102        loop {
103            match std::mem::replace(&mut self.state, State::Poisoned) {
104                State::ReadyToSend => {
105                    let message = self.write_buffer[..self.write_ptr].to_vec();
106                    self.state = State::ReadyPending { to_write: message };
107
108                    match futures::ready!(self.stream.poll_ready_unpin(cx)) {
109                        Ok(()) => continue,
110                        Err(_error) => {
111                            return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into()));
112                        }
113                    }
114                }
115                State::ReadyPending { to_write } => {
116                    match self.stream.start_send_unpin(Message::Binary(to_write.clone())) {
117                        Ok(_) => {
118                            self.state = State::FlushPending;
119                            continue;
120                        }
121                        Err(_error) =>
122                            return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
123                    }
124                }
125                State::FlushPending => match futures::ready!(self.stream.poll_flush_unpin(cx)) {
126                    Ok(_res) => {
127                        // TODO: optimize
128                        self.state = State::ReadyToSend;
129                        self.write_ptr = 0;
130                        self.write_buffer = Vec::with_capacity(2000);
131                        return Poll::Ready(Ok(()));
132                    }
133                    Err(_) => return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
134                },
135                State::Poisoned =>
136                    return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
137            }
138        }
139    }
140
141    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
142        match futures::ready!(self.stream.poll_close_unpin(cx)) {
143            Ok(_) => Poll::Ready(Ok(())),
144            Err(_) => Poll::Ready(Err(std::io::ErrorKind::PermissionDenied.into())),
145        }
146    }
147}
148
149impl<S: AsyncRead + AsyncWrite + Unpin> futures::AsyncRead for BufferedStream<S> {
150    fn poll_read(
151        mut self: Pin<&mut Self>,
152        cx: &mut Context<'_>,
153        buf: &mut [u8],
154    ) -> Poll<std::io::Result<usize>> {
155        loop {
156            if self.read_buffer.is_none() {
157                match self.stream.poll_next_unpin(cx) {
158                    Poll::Ready(Some(Ok(chunk))) => match chunk {
159                        Message::Binary(chunk) => self.read_buffer.replace(chunk.into()),
160                        _event => return Poll::Ready(Err(std::io::ErrorKind::Unsupported.into())),
161                    },
162                    Poll::Ready(Some(Err(_error))) =>
163                        return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
164                    Poll::Ready(None) => return Poll::Ready(Ok(0)),
165                    Poll::Pending => return Poll::Pending,
166                };
167            }
168
169            let buffer = self.read_buffer.as_mut().expect("buffer to exist");
170            let bytes_read = buf.len().min(buffer.len());
171            let _orig_size = buffer.len();
172            buf[..bytes_read].copy_from_slice(&buffer[..bytes_read]);
173
174            buffer.advance(bytes_read);
175
176            // TODO: this can't be correct
177            if !buffer.is_empty() || bytes_read != 0 {
178                return Poll::Ready(Ok(bytes_read));
179            } else {
180                self.read_buffer.take();
181            }
182        }
183    }
184}