// yamux/connection/stream.rs

// Copyright (c) 2018-2019 Parity Technologies (UK) Ltd.
//
// Licensed under the Apache License, Version 2.0 or MIT license, at your option.
//
// A copy of the Apache License, Version 2.0 is included in the software as
// LICENSE-APACHE and a copy of the MIT license is included in the software
// as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0
// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
// at https://opensource.org/licenses/MIT.
10
11use crate::connection::rtt::Rtt;
12use crate::frame::header::ACK;
13use crate::ConnectionError;
14use crate::{
15    chunks::Chunks,
16    connection::{self, rtt, StreamCommand},
17    frame::{
18        header::{Data, Header, StreamId, WindowUpdate},
19        Frame,
20    },
21    Config, DEFAULT_CREDIT,
22};
23use flow_control::FlowController;
24use futures::{
25    channel::mpsc,
26    future::Either,
27    io::{AsyncRead, AsyncWrite},
28    ready, SinkExt,
29};
30use parking_lot::{Mutex, MutexGuard};
31use std::{
32    fmt, io,
33    pin::Pin,
34    sync::Arc,
35    task::{Context, Poll, Waker},
36};
37
38mod flow_control;
39
/// Lifecycle state of a Yamux stream.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum State {
    /// Both directions are open.
    Open {
        /// Whether the stream has been acknowledged.
        ///
        /// * Outbound streams: the remote has acknowledged the stream we opened.
        /// * Inbound streams: we have sent our acknowledgement to the remote.
        ///
        /// Starts as `false` and becomes `true` once an `ACK` flag is received
        /// or sent for this stream. The stream may also leave `Open` without
        /// ever being acknowledged:
        /// - to `RecvClosed` if the remote immediately sends `FIN`;
        /// - to `Closed` if the remote immediately sends `RST`.
        acknowledged: bool,
    },
    /// Our sending half is closed; incoming messages are still accepted.
    SendClosed,
    /// The remote's sending half is closed; we may still send.
    RecvClosed,
    /// Both halves are closed. This is a terminal state.
    Closed,
}

impl State {
    /// Returns `true` while incoming messages may still arrive on this stream,
    /// i.e. in `Open` and `SendClosed`.
    pub fn can_read(self) -> bool {
        match self {
            State::Open { .. } | State::SendClosed => true,
            State::RecvClosed | State::Closed => false,
        }
    }

    /// Returns `true` while we may still send on this stream,
    /// i.e. in `Open` and `RecvClosed`.
    pub fn can_write(self) -> bool {
        match self {
            State::Open { .. } | State::RecvClosed => true,
            State::SendClosed | State::Closed => false,
        }
    }
}
75
/// Handshake flag (if any) that must still be attached to the next outbound
/// header of a stream.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum Flag {
    /// Nothing is pending; headers go out without an extra flag.
    None,
    /// Lazily opened outbound stream: the next header must carry `SYN`.
    Syn,
    /// Inbound stream not yet acknowledged: the next header must carry `ACK`.
    Ack,
}
86
87/// A multiplexed Yamux stream.
88///
89/// Streams are created either outbound via [`crate::Connection::poll_new_outbound`]
90/// or inbound via [`crate::Connection::poll_next_inbound`].
91///
92/// `Stream` implements [`AsyncRead`] and [`AsyncWrite`] and also
93/// [`futures::stream::Stream`].
94pub struct Stream {
95    id: StreamId,
96    conn: connection::Id,
97    config: Arc<Config>,
98    sender: mpsc::Sender<StreamCommand>,
99    flag: Flag,
100    shared: Arc<Mutex<Shared>>,
101}
102
103impl fmt::Debug for Stream {
104    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
105        f.debug_struct("Stream")
106            .field("id", &self.id.val())
107            .field("connection", &self.conn)
108            .finish()
109    }
110}
111
112impl fmt::Display for Stream {
113    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
114        write!(f, "(Stream {}/{})", self.conn, self.id.val())
115    }
116}
117
118impl Stream {
119    pub(crate) fn new_inbound(
120        id: StreamId,
121        conn: connection::Id,
122        config: Arc<Config>,
123        send_window: u32,
124        sender: mpsc::Sender<StreamCommand>,
125        rtt: rtt::Rtt,
126        accumulated_max_stream_windows: Arc<Mutex<usize>>,
127    ) -> Self {
128        Self {
129            id,
130            conn,
131            config: config.clone(),
132            sender,
133            flag: Flag::Ack,
134            shared: Arc::new(Mutex::new(Shared::new(
135                DEFAULT_CREDIT,
136                send_window,
137                accumulated_max_stream_windows,
138                rtt,
139                config,
140            ))),
141        }
142    }
143
144    pub(crate) fn new_outbound(
145        id: StreamId,
146        conn: connection::Id,
147        config: Arc<Config>,
148        sender: mpsc::Sender<StreamCommand>,
149        rtt: rtt::Rtt,
150        accumulated_max_stream_windows: Arc<Mutex<usize>>,
151    ) -> Self {
152        Self {
153            id,
154            conn,
155            config: config.clone(),
156            sender,
157            flag: Flag::Syn,
158            shared: Arc::new(Mutex::new(Shared::new(
159                DEFAULT_CREDIT,
160                DEFAULT_CREDIT,
161                accumulated_max_stream_windows,
162                rtt,
163                config,
164            ))),
165        }
166    }
167
168    /// Get this stream's identifier.
169    pub fn id(&self) -> StreamId {
170        self.id
171    }
172
173    pub fn is_write_closed(&self) -> bool {
174        matches!(self.shared().state(), State::SendClosed)
175    }
176
177    pub fn is_closed(&self) -> bool {
178        matches!(self.shared().state(), State::Closed)
179    }
180
181    /// Whether we are still waiting for the remote to acknowledge this stream.
182    pub fn is_pending_ack(&self) -> bool {
183        self.shared().is_pending_ack()
184    }
185
186    pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> {
187        self.shared.lock()
188    }
189
190    pub(crate) fn clone_shared(&self) -> Arc<Mutex<Shared>> {
191        self.shared.clone()
192    }
193
194    fn write_zero_err(&self) -> io::Error {
195        let msg = format!("{}/{}: connection is closed", self.conn, self.id);
196        io::Error::new(io::ErrorKind::WriteZero, msg)
197    }
198
199    /// Set ACK or SYN flag if necessary.
200    fn add_flag(&mut self, header: &mut Header<Either<Data, WindowUpdate>>) {
201        match self.flag {
202            Flag::None => (),
203            Flag::Syn => {
204                header.syn();
205                self.flag = Flag::None
206            }
207            Flag::Ack => {
208                header.ack();
209                self.flag = Flag::None
210            }
211        }
212    }
213
214    /// Send new credit to the sending side via a window update message if
215    /// permitted.
216    fn send_window_update(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
217        if !self.shared.lock().state.can_read() {
218            return Poll::Ready(Ok(()));
219        }
220
221        ready!(self
222            .sender
223            .poll_ready(cx)
224            .map_err(|_| self.write_zero_err())?);
225
226        let Some(credit) = self.shared.lock().next_window_update() else {
227            return Poll::Ready(Ok(()));
228        };
229
230        let mut frame = Frame::window_update(self.id, credit).right();
231        self.add_flag(frame.header_mut());
232        let cmd = StreamCommand::SendFrame(frame);
233        self.sender
234            .start_send(cmd)
235            .map_err(|_| self.write_zero_err())?;
236
237        Poll::Ready(Ok(()))
238    }
239}
240
/// One chunk of received bytes, as yielded by the
/// [`futures::stream::Stream`] impl of [`Stream`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Packet(Vec<u8>);

impl AsRef<[u8]> for Packet {
    /// Borrow the packet's bytes.
    fn as_ref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
250
251impl futures::stream::Stream for Stream {
252    type Item = io::Result<Packet>;
253
254    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
255        if !self.config.read_after_close && self.sender.is_closed() {
256            return Poll::Ready(None);
257        }
258
259        match self.send_window_update(cx) {
260            Poll::Ready(Ok(())) => {}
261            Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))),
262            // Continue reading buffered data even though sending a window update blocked.
263            Poll::Pending => {}
264        }
265
266        let mut shared = self.shared();
267
268        if let Some(bytes) = shared.buffer.pop() {
269            let off = bytes.offset();
270            let mut vec = bytes.into_vec();
271            if off != 0 {
272                // This should generally not happen when the stream is used only as
273                // a `futures::stream::Stream` since the whole point of this impl is
274                // to consume chunks atomically. It may perhaps happen when mixing
275                // this impl and the `AsyncRead` one.
276                log::debug!(
277                    "{}/{}: chunk has been partially consumed",
278                    self.conn,
279                    self.id
280                );
281                vec = vec.split_off(off)
282            }
283            return Poll::Ready(Some(Ok(Packet(vec))));
284        }
285
286        // Buffer is empty, let's check if we can expect to read more data.
287        if !shared.state().can_read() {
288            log::debug!("{}/{}: eof", self.conn, self.id);
289            return Poll::Ready(None); // stream has been reset
290        }
291
292        // Since we have no more data at this point, we want to be woken up
293        // by the connection when more becomes available for us.
294        shared.reader = Some(cx.waker().clone());
295
296        Poll::Pending
297    }
298}
299
300// Like the `futures::stream::Stream` impl above, but copies bytes into the
301// provided mutable slice.
302impl AsyncRead for Stream {
303    fn poll_read(
304        mut self: Pin<&mut Self>,
305        cx: &mut Context,
306        buf: &mut [u8],
307    ) -> Poll<io::Result<usize>> {
308        if !self.config.read_after_close && self.sender.is_closed() {
309            return Poll::Ready(Ok(0));
310        }
311
312        match self.send_window_update(cx) {
313            Poll::Ready(Ok(())) => {}
314            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
315            // Continue reading buffered data even though sending a window update blocked.
316            Poll::Pending => {}
317        }
318
319        // Copy data from stream buffer.
320        let mut shared = self.shared();
321        let mut n = 0;
322        while let Some(chunk) = shared.buffer.front_mut() {
323            if chunk.is_empty() {
324                shared.buffer.pop();
325                continue;
326            }
327            let k = std::cmp::min(chunk.len(), buf.len() - n);
328            buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
329            n += k;
330            chunk.advance(k);
331            if n == buf.len() {
332                break;
333            }
334        }
335
336        if n > 0 {
337            log::trace!("{}/{}: read {} bytes", self.conn, self.id, n);
338            return Poll::Ready(Ok(n));
339        }
340
341        // Buffer is empty, let's check if we can expect to read more data.
342        if !shared.state().can_read() {
343            log::debug!("{}/{}: eof", self.conn, self.id);
344            return Poll::Ready(Ok(0)); // stream has been reset
345        }
346
347        // Since we have no more data at this point, we want to be woken up
348        // by the connection when more becomes available for us.
349        shared.reader = Some(cx.waker().clone());
350
351        Poll::Pending
352    }
353}
354
355impl AsyncWrite for Stream {
356    fn poll_write(
357        mut self: Pin<&mut Self>,
358        cx: &mut Context,
359        buf: &[u8],
360    ) -> Poll<io::Result<usize>> {
361        ready!(self
362            .sender
363            .poll_ready(cx)
364            .map_err(|_| self.write_zero_err())?);
365        let body = {
366            let mut shared = self.shared();
367            if !shared.state().can_write() {
368                log::debug!("{}/{}: can no longer write", self.conn, self.id);
369                return Poll::Ready(Err(self.write_zero_err()));
370            }
371            if shared.send_window() == 0 {
372                log::trace!("{}/{}: no more credit left", self.conn, self.id);
373                shared.writer = Some(cx.waker().clone());
374                return Poll::Pending;
375            }
376            let k = std::cmp::min(shared.send_window() as usize, buf.len());
377            let k = std::cmp::min(k, self.config.split_send_size);
378            shared
379                .consume_send_window(k as u32)
380                .expect("not exceed receive window");
381            Vec::from(&buf[..k])
382        };
383        let n = body.len();
384        let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left();
385        self.add_flag(frame.header_mut());
386        log::trace!("{}/{}: write {} bytes", self.conn, self.id, n);
387
388        // technically, the frame hasn't been sent yet on the wire but from the perspective of this data structure, we've queued the frame for sending
389        // We are tracking this information:
390        // a) to be consistent with outbound streams
391        // b) to correctly test our behaviour around timing of when ACKs are sent. See `ack_timing.rs` test.
392        if frame.header().flags().contains(ACK) {
393            self.shared()
394                .update_state(self.conn, self.id, State::Open { acknowledged: true });
395        }
396
397        let cmd = StreamCommand::SendFrame(frame);
398        self.sender
399            .start_send(cmd)
400            .map_err(|_| self.write_zero_err())?;
401        Poll::Ready(Ok(n))
402    }
403
404    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
405        self.sender
406            .poll_flush_unpin(cx)
407            .map_err(|_| self.write_zero_err())
408    }
409
410    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
411        if self.is_closed() {
412            return Poll::Ready(Ok(()));
413        }
414        ready!(self
415            .sender
416            .poll_ready(cx)
417            .map_err(|_| self.write_zero_err())?);
418        let ack = if self.flag == Flag::Ack {
419            self.flag = Flag::None;
420            true
421        } else {
422            false
423        };
424        log::trace!("{}/{}: close", self.conn, self.id);
425        let cmd = StreamCommand::CloseStream { ack };
426        self.sender
427            .start_send(cmd)
428            .map_err(|_| self.write_zero_err())?;
429        self.shared()
430            .update_state(self.conn, self.id, State::SendClosed);
431        Poll::Ready(Ok(()))
432    }
433}
434
435#[derive(Debug)]
436pub(crate) struct Shared {
437    state: State,
438    flow_controller: FlowController,
439    pub(crate) buffer: Chunks,
440    pub(crate) reader: Option<Waker>,
441    pub(crate) writer: Option<Waker>,
442}
443
444impl Shared {
445    fn new(
446        receive_window: u32,
447        send_window: u32,
448        accumulated_max_stream_windows: Arc<Mutex<usize>>,
449        rtt: Rtt,
450        config: Arc<Config>,
451    ) -> Self {
452        Shared {
453            state: State::Open {
454                acknowledged: false,
455            },
456            flow_controller: FlowController::new(
457                receive_window,
458                send_window,
459                accumulated_max_stream_windows,
460                rtt,
461                config,
462            ),
463            buffer: Chunks::new(),
464            reader: None,
465            writer: None,
466        }
467    }
468
469    pub(crate) fn state(&self) -> State {
470        self.state
471    }
472
473    /// Update the stream state and return the state before it was updated.
474    pub(crate) fn update_state(
475        &mut self,
476        cid: connection::Id,
477        sid: StreamId,
478        next: State,
479    ) -> State {
480        use self::State::*;
481
482        let current = self.state;
483
484        match (current, next) {
485            (Closed, _) => {}
486            (Open { .. }, _) => self.state = next,
487            (RecvClosed, Closed) => self.state = Closed,
488            (RecvClosed, Open { .. }) => {}
489            (RecvClosed, RecvClosed) => {}
490            (RecvClosed, SendClosed) => self.state = Closed,
491            (SendClosed, Closed) => self.state = Closed,
492            (SendClosed, Open { .. }) => {}
493            (SendClosed, RecvClosed) => self.state = Closed,
494            (SendClosed, SendClosed) => {}
495        }
496
497        log::trace!(
498            "{}/{}: update state: (from {:?} to {:?} -> {:?})",
499            cid,
500            sid,
501            current,
502            next,
503            self.state
504        );
505
506        current // Return the previous stream state for informational purposes.
507    }
508
509    pub(crate) fn next_window_update(&mut self) -> Option<u32> {
510        self.flow_controller.next_window_update(self.buffer.len())
511    }
512
513    /// Whether we are still waiting for the remote to acknowledge this stream.
514    pub fn is_pending_ack(&self) -> bool {
515        matches!(
516            self.state(),
517            State::Open {
518                acknowledged: false
519            }
520        )
521    }
522
523    pub(crate) fn send_window(&self) -> u32 {
524        self.flow_controller.send_window()
525    }
526
527    pub(crate) fn consume_send_window(&mut self, i: u32) -> Result<(), ConnectionError> {
528        self.flow_controller.consume_send_window(i)
529    }
530
531    pub(crate) fn increase_send_window_by(&mut self, i: u32) -> Result<(), ConnectionError> {
532        self.flow_controller.increase_send_window_by(i)
533    }
534
535    pub(crate) fn consume_receive_window(&mut self, i: u32) -> Result<(), ConnectionError> {
536        self.flow_controller.consume_receive_window(i)
537    }
538}