litep2p/transport/websocket/
stream.rs1use bytes::{Buf, Bytes};
25use futures::{SinkExt, StreamExt};
26use tokio::io::{AsyncRead, AsyncWrite};
27use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
28
29use std::{
30 pin::Pin,
31 task::{Context, Poll},
32};
33
/// Progress of the write half while buffered bytes are pushed into the WebSocket sink.
enum State {
    /// Idle: the sink can be polled for readiness and handed the next frame.
    ReadyToSend,

    /// A frame payload (`to_write`) staged to be handed to the sink via `start_send`.
    ReadyPending { to_write: Vec<u8> },

    /// A frame was handed to the sink and the sink still has to be flushed.
    FlushPending,

    /// Placeholder swapped in while the real state is temporarily moved out; the write
    /// path treats encountering it as an error.
    Poisoned,
}
50
51pub(super) struct BufferedStream<S: AsyncRead + AsyncWrite + Unpin> {
53 write_buffer: Vec<u8>,
55
56 write_ptr: usize,
58
59 read_buffer: Option<Bytes>,
61
62 stream: WebSocketStream<S>,
64
65 state: State,
67}
68
69impl<S: AsyncRead + AsyncWrite + Unpin> BufferedStream<S> {
70 pub(super) fn new(stream: WebSocketStream<S>) -> Self {
72 Self {
73 write_buffer: Vec::with_capacity(2000),
74 read_buffer: None,
75 write_ptr: 0usize,
76 stream,
77 state: State::ReadyToSend,
78 }
79 }
80}
81
82impl<S: AsyncRead + AsyncWrite + Unpin> futures::AsyncWrite for BufferedStream<S> {
83 fn poll_write(
84 mut self: Pin<&mut Self>,
85 _cx: &mut Context<'_>,
86 buf: &[u8],
87 ) -> Poll<std::io::Result<usize>> {
88 self.write_buffer.extend_from_slice(buf);
89 self.write_ptr += buf.len();
90
91 Poll::Ready(Ok(buf.len()))
92 }
93
94 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
95 if self.write_buffer.is_empty() {
96 return self
97 .stream
98 .poll_ready_unpin(cx)
99 .map_err(|_| std::io::ErrorKind::UnexpectedEof.into());
100 }
101
102 loop {
103 match std::mem::replace(&mut self.state, State::Poisoned) {
104 State::ReadyToSend => {
105 let message = self.write_buffer[..self.write_ptr].to_vec();
106 self.state = State::ReadyPending { to_write: message };
107
108 match futures::ready!(self.stream.poll_ready_unpin(cx)) {
109 Ok(()) => continue,
110 Err(_error) => {
111 return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into()));
112 }
113 }
114 }
115 State::ReadyPending { to_write } => {
116 match self.stream.start_send_unpin(Message::Binary(to_write.clone())) {
117 Ok(_) => {
118 self.state = State::FlushPending;
119 continue;
120 }
121 Err(_error) =>
122 return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
123 }
124 }
125 State::FlushPending => match futures::ready!(self.stream.poll_flush_unpin(cx)) {
126 Ok(_res) => {
127 self.state = State::ReadyToSend;
129 self.write_ptr = 0;
130 self.write_buffer = Vec::with_capacity(2000);
131 return Poll::Ready(Ok(()));
132 }
133 Err(_) => return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
134 },
135 State::Poisoned =>
136 return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
137 }
138 }
139 }
140
141 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
142 match futures::ready!(self.stream.poll_close_unpin(cx)) {
143 Ok(_) => Poll::Ready(Ok(())),
144 Err(_) => Poll::Ready(Err(std::io::ErrorKind::PermissionDenied.into())),
145 }
146 }
147}
148
149impl<S: AsyncRead + AsyncWrite + Unpin> futures::AsyncRead for BufferedStream<S> {
150 fn poll_read(
151 mut self: Pin<&mut Self>,
152 cx: &mut Context<'_>,
153 buf: &mut [u8],
154 ) -> Poll<std::io::Result<usize>> {
155 loop {
156 if self.read_buffer.is_none() {
157 match self.stream.poll_next_unpin(cx) {
158 Poll::Ready(Some(Ok(chunk))) => match chunk {
159 Message::Binary(chunk) => self.read_buffer.replace(chunk.into()),
160 _event => return Poll::Ready(Err(std::io::ErrorKind::Unsupported.into())),
161 },
162 Poll::Ready(Some(Err(_error))) =>
163 return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
164 Poll::Ready(None) => return Poll::Ready(Ok(0)),
165 Poll::Pending => return Poll::Pending,
166 };
167 }
168
169 let buffer = self.read_buffer.as_mut().expect("buffer to exist");
170 let bytes_read = buf.len().min(buffer.len());
171 let _orig_size = buffer.len();
172 buf[..bytes_read].copy_from_slice(&buffer[..bytes_read]);
173
174 buffer.advance(bytes_read);
175
176 if !buffer.is_empty() || bytes_read != 0 {
178 return Poll::Ready(Ok(bytes_read));
179 } else {
180 self.read_buffer.take();
181 }
182 }
183 }
184}