1use crate::yamux::{
12 chunks::Chunks,
13 connection::{self, StreamCommand},
14 frame::{
15 header::{Data, Header, StreamId, WindowUpdate, ACK},
16 Frame,
17 },
18 Config, WindowUpdateMode, DEFAULT_CREDIT,
19};
20use futures::{
21 channel::mpsc,
22 future::Either,
23 io::{AsyncRead, AsyncWrite},
24 ready, SinkExt,
25};
26use parking_lot::{Mutex, MutexGuard};
27use std::{
28 convert::TryInto,
29 fmt, io,
30 pin::Pin,
31 sync::Arc,
32 task::{Context, Poll, Waker},
33};
34
/// `tracing` target passed to the log statements in this file.
const LOG_TARGET: &str = "litep2p::yamux";
37
/// Lifecycle state of a stream; each half (read / write) can be closed independently.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum State {
    /// Both halves are open.
    Open {
        /// `false` until the stream is acknowledged. In this file it is set to `true` when
        /// `poll_write` queues a frame carrying `ACK`; other transitions presumably happen in
        /// the connection — confirm there.
        acknowledged: bool,
    },
    /// Our write half is closed; data can still be received.
    SendClosed,
    /// The read half is closed; data can still be sent.
    RecvClosed,
    /// Both halves are closed.
    Closed,
}

impl State {
    /// Whether data may still be received, i.e. the read half is not closed.
    pub fn can_read(self) -> bool {
        match self {
            State::Open { .. } | State::SendClosed => true,
            State::RecvClosed | State::Closed => false,
        }
    }

    /// Whether data may still be sent, i.e. the write half is not closed.
    pub fn can_write(self) -> bool {
        match self {
            State::Open { .. } | State::RecvClosed => true,
            State::SendClosed | State::Closed => false,
        }
    }
}
73
/// Flag waiting to be attached to the next outgoing frame (or, for `Ack`, to the close
/// command); it is reset to `None` once used.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum Flag {
    /// Nothing pending.
    None,
    /// Set `SYN` on the next frame.
    Syn,
    /// Set `ACK` on the next frame.
    Ack,
}
84
/// A single multiplexed stream.
///
/// Reads and writes go through the lock-protected [`Shared`] state. Outgoing frames and
/// close requests are sent as [`StreamCommand`]s over `sender`; the receiver is presumably
/// the owning connection task — confirm at the construction site.
pub struct Stream {
    /// Id of this stream within its connection.
    id: StreamId,
    /// Id of the connection this stream belongs to (used in logs and errors).
    conn: connection::Id,
    /// Multiplexer configuration.
    config: Arc<Config>,
    /// Command channel for frames and close requests.
    sender: mpsc::Sender<StreamCommand>,
    /// Flag to attach to the next outgoing frame; see `add_flag`.
    flag: Flag,
    /// State that can be handed out via `clone_shared`.
    shared: Arc<Mutex<Shared>>,
}
100
101impl fmt::Debug for Stream {
102 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
103 f.debug_struct("Stream")
104 .field("id", &self.id.val())
105 .field("connection", &self.conn)
106 .finish()
107 }
108}
109
110impl fmt::Display for Stream {
111 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
112 write!(f, "(Stream {}/{})", self.conn, self.id.val())
113 }
114}
115
116impl Stream {
117 pub(crate) fn new_inbound(
118 id: StreamId,
119 conn: connection::Id,
120 config: Arc<Config>,
121 credit: u32,
122 sender: mpsc::Sender<StreamCommand>,
123 ) -> Self {
124 Self {
125 id,
126 conn,
127 config: config.clone(),
128 sender,
129 flag: Flag::None,
130 shared: Arc::new(Mutex::new(Shared::new(DEFAULT_CREDIT, credit, config))),
131 }
132 }
133
134 pub(crate) fn new_outbound(
135 id: StreamId,
136 conn: connection::Id,
137 config: Arc<Config>,
138 window: u32,
139 sender: mpsc::Sender<StreamCommand>,
140 ) -> Self {
141 Self {
142 id,
143 conn,
144 config: config.clone(),
145 sender,
146 flag: Flag::None,
147 shared: Arc::new(Mutex::new(Shared::new(window, DEFAULT_CREDIT, config))),
148 }
149 }
150
151 pub fn id(&self) -> StreamId {
153 self.id
154 }
155
156 pub fn is_write_closed(&self) -> bool {
157 matches!(self.shared().state(), State::SendClosed)
158 }
159
160 pub fn is_closed(&self) -> bool {
161 matches!(self.shared().state(), State::Closed)
162 }
163
164 pub fn is_pending_ack(&self) -> bool {
166 self.shared().is_pending_ack()
167 }
168
169 pub(crate) fn set_flag(&mut self, flag: Flag) {
171 self.flag = flag
172 }
173
174 pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> {
175 self.shared.lock()
176 }
177
178 pub(crate) fn clone_shared(&self) -> Arc<Mutex<Shared>> {
179 self.shared.clone()
180 }
181
182 fn write_zero_err(&self) -> io::Error {
183 let msg = format!("{}/{}: connection is closed", self.conn, self.id);
184 io::Error::new(io::ErrorKind::WriteZero, msg)
185 }
186
187 fn add_flag(&mut self, header: &mut Header<Either<Data, WindowUpdate>>) {
189 match self.flag {
190 Flag::None => (),
191 Flag::Syn => {
192 header.syn();
193 self.flag = Flag::None
194 }
195 Flag::Ack => {
196 header.ack();
197 self.flag = Flag::None
198 }
199 }
200 }
201
202 fn send_window_update(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
205 if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) {
208 return Poll::Ready(Ok(()));
209 }
210
211 let mut shared = self.shared.lock();
212
213 if let Some(credit) = shared.next_window_update() {
214 ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?);
215
216 shared.window += credit;
217 drop(shared);
218
219 let mut frame = Frame::window_update(self.id, credit).right();
220 self.add_flag(frame.header_mut());
221 let cmd = StreamCommand::SendFrame(frame);
222 self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?;
223 }
224
225 Poll::Ready(Ok(()))
226 }
227}
228
/// One buffered chunk of stream data, yielded whole by the `futures::stream::Stream` impl.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Packet(Vec<u8>);

impl AsRef<[u8]> for Packet {
    /// Borrows the packet payload.
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}
238
239impl futures::stream::Stream for Stream {
240 type Item = io::Result<Packet>;
241
242 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
243 if !self.config.read_after_close && self.sender.is_closed() {
244 return Poll::Ready(None);
245 }
246
247 match self.send_window_update(cx) {
248 Poll::Ready(Ok(())) => {}
249 Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))),
250 Poll::Pending => {}
252 }
253
254 let mut shared = self.shared();
255
256 if let Some(bytes) = shared.buffer.pop() {
257 let off = bytes.offset();
258 let mut vec = bytes.into_vec();
259 if off != 0 {
260 tracing::debug!(
265 target: LOG_TARGET,
266 "{}/{}: chunk has been partially consumed",
267 self.conn,
268 self.id
269 );
270 vec = vec.split_off(off)
271 }
272 return Poll::Ready(Some(Ok(Packet(vec))));
273 }
274
275 if !shared.state().can_read() {
277 tracing::debug!(target: LOG_TARGET, "{}/{}: eof", self.conn, self.id);
278 return Poll::Ready(None); }
280
281 shared.reader = Some(cx.waker().clone());
284
285 Poll::Pending
286 }
287}
288
289impl AsyncRead for Stream {
292 fn poll_read(
293 mut self: Pin<&mut Self>,
294 cx: &mut Context,
295 buf: &mut [u8],
296 ) -> Poll<io::Result<usize>> {
297 if !self.config.read_after_close && self.sender.is_closed() {
298 return Poll::Ready(Ok(0));
299 }
300
301 match self.send_window_update(cx) {
302 Poll::Ready(Ok(())) => {}
303 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
304 Poll::Pending => {}
306 }
307
308 let mut shared = self.shared();
310 let mut n = 0;
311 while let Some(chunk) = shared.buffer.front_mut() {
312 if chunk.is_empty() {
313 shared.buffer.pop();
314 continue;
315 }
316 let k = std::cmp::min(chunk.len(), buf.len() - n);
317 buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
318 n += k;
319 chunk.advance(k);
320 if n == buf.len() {
321 break;
322 }
323 }
324
325 if n > 0 {
326 tracing::trace!(target: LOG_TARGET,"{}/{}: read {} bytes", self.conn, self.id, n);
327 return Poll::Ready(Ok(n));
328 }
329
330 if !shared.state().can_read() {
332 tracing::debug!(target: LOG_TARGET,"{}/{}: eof", self.conn, self.id);
333 return Poll::Ready(Ok(0)); }
335
336 shared.reader = Some(cx.waker().clone());
339
340 Poll::Pending
341 }
342}
343
344impl AsyncWrite for Stream {
345 fn poll_write(
346 mut self: Pin<&mut Self>,
347 cx: &mut Context,
348 buf: &[u8],
349 ) -> Poll<io::Result<usize>> {
350 ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?);
351 let body = {
352 let mut shared = self.shared();
353 if !shared.state().can_write() {
354 tracing::debug!(target: LOG_TARGET,"{}/{}: can no longer write", self.conn, self.id);
355 return Poll::Ready(Err(self.write_zero_err()));
356 }
357 if shared.credit == 0 {
358 tracing::trace!(target: LOG_TARGET,"{}/{}: no more credit left", self.conn, self.id);
359 shared.writer = Some(cx.waker().clone());
360 return Poll::Pending;
361 }
362 let k = std::cmp::min(shared.credit as usize, buf.len());
363 let k = std::cmp::min(k, self.config.split_send_size);
364 shared.credit = shared.credit.saturating_sub(k as u32);
365 Vec::from(&buf[..k])
366 };
367 let n = body.len();
368 let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left();
369 self.add_flag(frame.header_mut());
370 tracing::trace!(target: LOG_TARGET,"{}/{}: write {} bytes", self.conn, self.id, n);
371
372 if frame.header().flags().contains(ACK) {
378 self.shared()
379 .update_state(self.conn, self.id, State::Open { acknowledged: true });
380 }
381
382 let cmd = StreamCommand::SendFrame(frame);
383 self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?;
384 Poll::Ready(Ok(n))
385 }
386
387 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
388 self.sender.poll_flush_unpin(cx).map_err(|_| self.write_zero_err())
389 }
390
391 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
392 if self.is_closed() {
393 return Poll::Ready(Ok(()));
394 }
395 ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?);
396 let ack = if self.flag == Flag::Ack {
397 self.flag = Flag::None;
398 true
399 } else {
400 false
401 };
402 tracing::trace!(target: LOG_TARGET,"{}/{}: close", self.conn, self.id);
403 let cmd = StreamCommand::CloseStream { ack };
404 self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?;
405 self.shared().update_state(self.conn, self.id, State::SendClosed);
406 Poll::Ready(Ok(()))
407 }
408}
409
/// Per-stream state kept behind a mutex and reachable through `Stream::shared` /
/// `Stream::clone_shared`.
#[derive(Debug)]
pub(crate) struct Shared {
    /// Current lifecycle state; change it via `update_state`.
    state: State,
    /// Remaining receive window. `next_window_update` treats
    /// `config.receive_window - window` as the number of bytes received.
    pub(crate) window: u32,
    /// Remaining send credit: bytes `poll_write` may still queue before it returns `Pending`.
    pub(crate) credit: u32,
    /// Received data not yet handed to the reader.
    pub(crate) buffer: Chunks,
    /// Waker of a reader waiting for data.
    pub(crate) reader: Option<Waker>,
    /// Waker of a writer waiting for send credit.
    pub(crate) writer: Option<Waker>,
    /// Multiplexer configuration (receive window size, window-update mode).
    config: Arc<Config>,
}
420
421impl Shared {
422 fn new(window: u32, credit: u32, config: Arc<Config>) -> Self {
423 Shared {
424 state: State::Open {
425 acknowledged: false,
426 },
427 window,
428 credit,
429 buffer: Chunks::new(),
430 reader: None,
431 writer: None,
432 config,
433 }
434 }
435
436 pub(crate) fn state(&self) -> State {
437 self.state
438 }
439
440 pub(crate) fn update_state(
442 &mut self,
443 cid: connection::Id,
444 sid: StreamId,
445 next: State,
446 ) -> State {
447 use self::State::*;
448
449 let current = self.state;
450
451 match (current, next) {
452 (Closed, _) => {}
453 (Open { .. }, _) => self.state = next,
454 (RecvClosed, Closed) => self.state = Closed,
455 (RecvClosed, Open { .. }) => {}
456 (RecvClosed, RecvClosed) => {}
457 (RecvClosed, SendClosed) => self.state = Closed,
458 (SendClosed, Closed) => self.state = Closed,
459 (SendClosed, Open { .. }) => {}
460 (SendClosed, RecvClosed) => self.state = Closed,
461 (SendClosed, SendClosed) => {}
462 }
463
464 tracing::trace!(target: LOG_TARGET,
465 "{}/{}: update state: (from {:?} to {:?} -> {:?})",
466 cid,
467 sid,
468 current,
469 next,
470 self.state
471 );
472
473 current }
475
476 pub(crate) fn next_window_update(&mut self) -> Option<u32> {
484 if !self.state.can_read() {
485 return None;
486 }
487
488 let new_credit = match self.config.window_update_mode {
489 WindowUpdateMode::OnReceive => {
490 debug_assert!(self.config.receive_window >= self.window);
491
492 self.config.receive_window.saturating_sub(self.window)
493 }
494 WindowUpdateMode::OnRead => {
495 debug_assert!(self.config.receive_window >= self.window);
496 let bytes_received = self.config.receive_window.saturating_sub(self.window);
497 let buffer_len: u32 = self.buffer.len().try_into().unwrap_or(u32::MAX);
498
499 bytes_received.saturating_sub(buffer_len)
500 }
501 };
502
503 if new_credit >= self.config.receive_window / 2 {
509 Some(new_credit)
510 } else {
511 None
512 }
513 }
514
515 pub fn is_pending_ack(&self) -> bool {
517 matches!(
518 self.state(),
519 State::Open {
520 acknowledged: false
521 }
522 )
523 }
524}