// yamux/connection/stream.rs

// Copyright (c) 2018-2019 Parity Technologies (UK) Ltd.
//
// Licensed under the Apache License, Version 2.0 or MIT license, at your option.
//
// A copy of the Apache License, Version 2.0 is included in the software as
// LICENSE-APACHE and a copy of the MIT license is included in the software
// as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0
// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
// at https://opensource.org/licenses/MIT.
10
11use crate::connection::rtt::Rtt;
12use crate::frame::header::ACK;
13use crate::{
14    chunks::Chunks,
15    connection::{self, rtt, StreamCommand},
16    frame::{
17        header::{Data, Header, StreamId, WindowUpdate},
18        Frame,
19    },
20    Config, DEFAULT_CREDIT,
21};
22use flow_control::FlowController;
23use futures::{
24    channel::mpsc,
25    future::Either,
26    io::{AsyncRead, AsyncWrite},
27    ready, SinkExt,
28};
29use parking_lot::{Mutex, MutexGuard};
30use std::{
31    fmt, io,
32    pin::Pin,
33    sync::Arc,
34    task::{Context, Poll, Waker},
35};
36
37mod flow_control;
38
/// Lifecycle state of a Yamux stream, as seen from our end.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum State {
    /// Both directions are open.
    Open {
        /// Whether the stream is acknowledged.
        ///
        /// Outbound streams: whether the remote has acknowledged our stream.
        /// Inbound streams: whether we have acknowledged the stream to the remote.
        ///
        /// Starts out `false` and flips to `true` once an `ACK` flag for this
        /// stream is received or sent. The stream may also skip acknowledgement:
        /// - `Open` goes straight to `RecvClosed` if the remote immediately sends `FIN`.
        /// - `Open` goes straight to `Closed` if the remote immediately sends `RST`.
        acknowledged: bool,
    },
    /// Our sending half is closed; incoming messages are still accepted.
    SendClosed,
    /// The receiving half is closed; outgoing messages are still allowed.
    RecvClosed,
    /// Both halves are closed. This state is terminal.
    Closed,
}

impl State {
    /// Returns `true` unless the receiving half is closed, i.e. unless the
    /// state is [`State::RecvClosed`] or [`State::Closed`].
    pub fn can_read(self) -> bool {
        match self {
            State::Open { .. } | State::SendClosed => true,
            State::RecvClosed | State::Closed => false,
        }
    }

    /// Returns `true` unless the sending half is closed, i.e. unless the
    /// state is [`State::SendClosed`] or [`State::Closed`].
    pub fn can_write(self) -> bool {
        match self {
            State::Open { .. } | State::RecvClosed => true,
            State::SendClosed | State::Closed => false,
        }
    }
}
74
/// Which flag, if any, still has to be attached to the next outbound header
/// of a stream.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum Flag {
    /// Nothing left to attach.
    None,
    /// The stream was opened lazily; the next header must carry `SYN`.
    Syn,
    /// The stream has not been acknowledged yet; the next header must carry `ACK`.
    Ack,
}
85
86/// A multiplexed Yamux stream.
87///
88/// Streams are created either outbound via [`crate::Connection::poll_new_outbound`]
89/// or inbound via [`crate::Connection::poll_next_inbound`].
90///
91/// `Stream` implements [`AsyncRead`] and [`AsyncWrite`] and also
92/// [`futures::stream::Stream`].
93pub struct Stream {
94    id: StreamId,
95    conn: connection::Id,
96    config: Arc<Config>,
97    sender: mpsc::Sender<StreamCommand>,
98    flag: Flag,
99    shared: Arc<Mutex<Shared>>,
100}
101
102impl fmt::Debug for Stream {
103    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
104        f.debug_struct("Stream")
105            .field("id", &self.id.val())
106            .field("connection", &self.conn)
107            .finish()
108    }
109}
110
111impl fmt::Display for Stream {
112    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
113        write!(f, "(Stream {}/{})", self.conn, self.id.val())
114    }
115}
116
117impl Stream {
118    pub(crate) fn new_inbound(
119        id: StreamId,
120        conn: connection::Id,
121        config: Arc<Config>,
122        send_window: u32,
123        sender: mpsc::Sender<StreamCommand>,
124        rtt: rtt::Rtt,
125        accumulated_max_stream_windows: Arc<Mutex<usize>>,
126    ) -> Self {
127        Self {
128            id,
129            conn,
130            config: config.clone(),
131            sender,
132            flag: Flag::Ack,
133            shared: Arc::new(Mutex::new(Shared::new(
134                DEFAULT_CREDIT,
135                send_window,
136                accumulated_max_stream_windows,
137                rtt,
138                config,
139            ))),
140        }
141    }
142
143    pub(crate) fn new_outbound(
144        id: StreamId,
145        conn: connection::Id,
146        config: Arc<Config>,
147        sender: mpsc::Sender<StreamCommand>,
148        rtt: rtt::Rtt,
149        accumulated_max_stream_windows: Arc<Mutex<usize>>,
150    ) -> Self {
151        Self {
152            id,
153            conn,
154            config: config.clone(),
155            sender,
156            flag: Flag::Syn,
157            shared: Arc::new(Mutex::new(Shared::new(
158                DEFAULT_CREDIT,
159                DEFAULT_CREDIT,
160                accumulated_max_stream_windows,
161                rtt,
162                config,
163            ))),
164        }
165    }
166
167    /// Get this stream's identifier.
168    pub fn id(&self) -> StreamId {
169        self.id
170    }
171
172    pub fn is_write_closed(&self) -> bool {
173        matches!(self.shared().state(), State::SendClosed)
174    }
175
176    pub fn is_closed(&self) -> bool {
177        matches!(self.shared().state(), State::Closed)
178    }
179
180    /// Whether we are still waiting for the remote to acknowledge this stream.
181    pub fn is_pending_ack(&self) -> bool {
182        self.shared().is_pending_ack()
183    }
184
185    pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> {
186        self.shared.lock()
187    }
188
189    pub(crate) fn clone_shared(&self) -> Arc<Mutex<Shared>> {
190        self.shared.clone()
191    }
192
193    fn write_zero_err(&self) -> io::Error {
194        let msg = format!("{}/{}: connection is closed", self.conn, self.id);
195        io::Error::new(io::ErrorKind::WriteZero, msg)
196    }
197
198    /// Set ACK or SYN flag if necessary.
199    fn add_flag(&mut self, header: &mut Header<Either<Data, WindowUpdate>>) {
200        match self.flag {
201            Flag::None => (),
202            Flag::Syn => {
203                header.syn();
204                self.flag = Flag::None
205            }
206            Flag::Ack => {
207                header.ack();
208                self.flag = Flag::None
209            }
210        }
211    }
212
213    /// Send new credit to the sending side via a window update message if
214    /// permitted.
215    fn send_window_update(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
216        if !self.shared.lock().state.can_read() {
217            return Poll::Ready(Ok(()));
218        }
219
220        ready!(self
221            .sender
222            .poll_ready(cx)
223            .map_err(|_| self.write_zero_err())?);
224
225        let Some(credit) = self.shared.lock().next_window_update() else {
226            return Poll::Ready(Ok(()));
227        };
228
229        let mut frame = Frame::window_update(self.id, credit).right();
230        self.add_flag(frame.header_mut());
231        let cmd = StreamCommand::SendFrame(frame);
232        self.sender
233            .start_send(cmd)
234            .map_err(|_| self.write_zero_err())?;
235
236        Poll::Ready(Ok(()))
237    }
238}
239
/// An owned run of bytes, one per item yielded by the
/// [`futures::stream::Stream`] impl of [`Stream`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Packet(Vec<u8>);

impl AsRef<[u8]> for Packet {
    /// Borrows the packet's bytes.
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}
249
250impl futures::stream::Stream for Stream {
251    type Item = io::Result<Packet>;
252
253    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
254        if !self.config.read_after_close && self.sender.is_closed() {
255            return Poll::Ready(None);
256        }
257
258        match self.send_window_update(cx) {
259            Poll::Ready(Ok(())) => {}
260            Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))),
261            // Continue reading buffered data even though sending a window update blocked.
262            Poll::Pending => {}
263        }
264
265        let mut shared = self.shared();
266
267        if let Some(bytes) = shared.buffer.pop() {
268            let off = bytes.offset();
269            let mut vec = bytes.into_vec();
270            if off != 0 {
271                // This should generally not happen when the stream is used only as
272                // a `futures::stream::Stream` since the whole point of this impl is
273                // to consume chunks atomically. It may perhaps happen when mixing
274                // this impl and the `AsyncRead` one.
275                log::debug!(
276                    "{}/{}: chunk has been partially consumed",
277                    self.conn,
278                    self.id
279                );
280                vec = vec.split_off(off)
281            }
282            return Poll::Ready(Some(Ok(Packet(vec))));
283        }
284
285        // Buffer is empty, let's check if we can expect to read more data.
286        if !shared.state().can_read() {
287            log::debug!("{}/{}: eof", self.conn, self.id);
288            return Poll::Ready(None); // stream has been reset
289        }
290
291        // Since we have no more data at this point, we want to be woken up
292        // by the connection when more becomes available for us.
293        shared.reader = Some(cx.waker().clone());
294
295        Poll::Pending
296    }
297}
298
299// Like the `futures::stream::Stream` impl above, but copies bytes into the
300// provided mutable slice.
301impl AsyncRead for Stream {
302    fn poll_read(
303        mut self: Pin<&mut Self>,
304        cx: &mut Context,
305        buf: &mut [u8],
306    ) -> Poll<io::Result<usize>> {
307        if !self.config.read_after_close && self.sender.is_closed() {
308            return Poll::Ready(Ok(0));
309        }
310
311        match self.send_window_update(cx) {
312            Poll::Ready(Ok(())) => {}
313            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
314            // Continue reading buffered data even though sending a window update blocked.
315            Poll::Pending => {}
316        }
317
318        // Copy data from stream buffer.
319        let mut shared = self.shared();
320        let mut n = 0;
321        while let Some(chunk) = shared.buffer.front_mut() {
322            if chunk.is_empty() {
323                shared.buffer.pop();
324                continue;
325            }
326            let k = std::cmp::min(chunk.len(), buf.len() - n);
327            buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
328            n += k;
329            chunk.advance(k);
330            if n == buf.len() {
331                break;
332            }
333        }
334
335        if n > 0 {
336            log::trace!("{}/{}: read {} bytes", self.conn, self.id, n);
337            return Poll::Ready(Ok(n));
338        }
339
340        // Buffer is empty, let's check if we can expect to read more data.
341        if !shared.state().can_read() {
342            log::debug!("{}/{}: eof", self.conn, self.id);
343            return Poll::Ready(Ok(0)); // stream has been reset
344        }
345
346        // Since we have no more data at this point, we want to be woken up
347        // by the connection when more becomes available for us.
348        shared.reader = Some(cx.waker().clone());
349
350        Poll::Pending
351    }
352}
353
354impl AsyncWrite for Stream {
355    fn poll_write(
356        mut self: Pin<&mut Self>,
357        cx: &mut Context,
358        buf: &[u8],
359    ) -> Poll<io::Result<usize>> {
360        ready!(self
361            .sender
362            .poll_ready(cx)
363            .map_err(|_| self.write_zero_err())?);
364        let body = {
365            let mut shared = self.shared();
366            if !shared.state().can_write() {
367                log::debug!("{}/{}: can no longer write", self.conn, self.id);
368                return Poll::Ready(Err(self.write_zero_err()));
369            }
370            if shared.send_window() == 0 {
371                log::trace!("{}/{}: no more credit left", self.conn, self.id);
372                shared.writer = Some(cx.waker().clone());
373                return Poll::Pending;
374            }
375            let k = std::cmp::min(
376                shared.send_window(),
377                buf.len().try_into().unwrap_or(u32::MAX),
378            );
379            let k = std::cmp::min(
380                k,
381                self.config.split_send_size.try_into().unwrap_or(u32::MAX),
382            );
383            shared.consume_send_window(k);
384            Vec::from(&buf[..k as usize])
385        };
386        let n = body.len();
387        let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left();
388        self.add_flag(frame.header_mut());
389        log::trace!("{}/{}: write {} bytes", self.conn, self.id, n);
390
391        // technically, the frame hasn't been sent yet on the wire but from the perspective of this data structure, we've queued the frame for sending
392        // We are tracking this information:
393        // a) to be consistent with outbound streams
394        // b) to correctly test our behaviour around timing of when ACKs are sent. See `ack_timing.rs` test.
395        if frame.header().flags().contains(ACK) {
396            self.shared()
397                .update_state(self.conn, self.id, State::Open { acknowledged: true });
398        }
399
400        let cmd = StreamCommand::SendFrame(frame);
401        self.sender
402            .start_send(cmd)
403            .map_err(|_| self.write_zero_err())?;
404        Poll::Ready(Ok(n))
405    }
406
407    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
408        self.sender
409            .poll_flush_unpin(cx)
410            .map_err(|_| self.write_zero_err())
411    }
412
413    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
414        if self.is_closed() {
415            return Poll::Ready(Ok(()));
416        }
417        ready!(self
418            .sender
419            .poll_ready(cx)
420            .map_err(|_| self.write_zero_err())?);
421        let ack = if self.flag == Flag::Ack {
422            self.flag = Flag::None;
423            true
424        } else {
425            false
426        };
427        log::trace!("{}/{}: close", self.conn, self.id);
428        let cmd = StreamCommand::CloseStream { ack };
429        self.sender
430            .start_send(cmd)
431            .map_err(|_| self.write_zero_err())?;
432        self.shared()
433            .update_state(self.conn, self.id, State::SendClosed);
434        Poll::Ready(Ok(()))
435    }
436}
437
438#[derive(Debug)]
439pub(crate) struct Shared {
440    state: State,
441    flow_controller: FlowController,
442    pub(crate) buffer: Chunks,
443    pub(crate) reader: Option<Waker>,
444    pub(crate) writer: Option<Waker>,
445}
446
447impl Shared {
448    fn new(
449        receive_window: u32,
450        send_window: u32,
451        accumulated_max_stream_windows: Arc<Mutex<usize>>,
452        rtt: Rtt,
453        config: Arc<Config>,
454    ) -> Self {
455        Shared {
456            state: State::Open {
457                acknowledged: false,
458            },
459            flow_controller: FlowController::new(
460                receive_window,
461                send_window,
462                accumulated_max_stream_windows,
463                rtt,
464                config,
465            ),
466            buffer: Chunks::new(),
467            reader: None,
468            writer: None,
469        }
470    }
471
472    pub(crate) fn state(&self) -> State {
473        self.state
474    }
475
476    /// Update the stream state and return the state before it was updated.
477    pub(crate) fn update_state(
478        &mut self,
479        cid: connection::Id,
480        sid: StreamId,
481        next: State,
482    ) -> State {
483        use self::State::*;
484
485        let current = self.state;
486
487        match (current, next) {
488            (Closed, _) => {}
489            (Open { .. }, _) => self.state = next,
490            (RecvClosed, Closed) => self.state = Closed,
491            (RecvClosed, Open { .. }) => {}
492            (RecvClosed, RecvClosed) => {}
493            (RecvClosed, SendClosed) => self.state = Closed,
494            (SendClosed, Closed) => self.state = Closed,
495            (SendClosed, Open { .. }) => {}
496            (SendClosed, RecvClosed) => self.state = Closed,
497            (SendClosed, SendClosed) => {}
498        }
499
500        log::trace!(
501            "{}/{}: update state: (from {:?} to {:?} -> {:?})",
502            cid,
503            sid,
504            current,
505            next,
506            self.state
507        );
508
509        current // Return the previous stream state for informational purposes.
510    }
511
512    pub(crate) fn next_window_update(&mut self) -> Option<u32> {
513        self.flow_controller.next_window_update(self.buffer.len())
514    }
515
516    /// Whether we are still waiting for the remote to acknowledge this stream.
517    pub fn is_pending_ack(&self) -> bool {
518        matches!(
519            self.state(),
520            State::Open {
521                acknowledged: false
522            }
523        )
524    }
525
526    pub(crate) fn send_window(&self) -> u32 {
527        self.flow_controller.send_window()
528    }
529
530    pub(crate) fn consume_send_window(&mut self, i: u32) {
531        self.flow_controller.consume_send_window(i)
532    }
533
534    pub(crate) fn increase_send_window_by(&mut self, i: u32) {
535        self.flow_controller.increase_send_window_by(i)
536    }
537
538    pub(crate) fn receive_window(&self) -> u32 {
539        self.flow_controller.receive_window()
540    }
541
542    pub(crate) fn consume_receive_window(&mut self, i: u32) {
543        self.flow_controller.consume_receive_window(i)
544    }
545}