yamux/
connection.rs

1// Copyright (c) 2018-2019 Parity Technologies (UK) Ltd.
2//
3// Licensed under the Apache License, Version 2.0 or MIT license, at your option.
4//
5// A copy of the Apache License, Version 2.0 is included in the software as
6// LICENSE-APACHE and a copy of the MIT license is included in the software
7// as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0
8// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
9// at https://opensource.org/licenses/MIT.
10
11//! This module contains the `Connection` type and associated helpers.
12//! A `Connection` wraps an underlying (async) I/O resource and multiplexes
13//! `Stream`s over it.
14
15mod cleanup;
16mod closing;
17mod stream;
18
19use crate::tagged_stream::TaggedStream;
20use crate::{
21    error::ConnectionError,
22    frame::header::{self, Data, GoAway, Header, Ping, StreamId, Tag, WindowUpdate, CONNECTION_ID},
23    frame::{self, Frame},
24    Config, WindowUpdateMode, DEFAULT_CREDIT,
25};
26use crate::{Result, MAX_ACK_BACKLOG};
27use cleanup::Cleanup;
28use closing::Closing;
29use futures::stream::SelectAll;
30use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse};
31use nohash_hasher::IntMap;
32use parking_lot::Mutex;
33use std::collections::VecDeque;
34use std::task::{Context, Waker};
35use std::{fmt, sync::Arc, task::Poll};
36
37pub use stream::{Packet, State, Stream};
38
/// The role in which a connection is used.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Mode {
    /// A connection from a client to a server.
    Client,
    /// A connection from a server to a client.
    Server,
}
47
/// Identifier of a single connection.
///
/// The value is generated randomly and mainly serves to make log output
/// easier to follow.
#[derive(Copy, Clone)]
pub(crate) struct Id(u32);
53
54impl Id {
55    /// Create a random connection ID.
56    pub(crate) fn random() -> Self {
57        Id(rand::random())
58    }
59}
60
61impl fmt::Debug for Id {
62    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
63        write!(f, "{:08x}", self.0)
64    }
65}
66
67impl fmt::Display for Id {
68    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
69        write!(f, "{:08x}", self.0)
70    }
71}
72
/// A Yamux connection object.
///
/// Wraps the underlying I/O resource and makes progress via its
/// [`Connection::poll_next_inbound`] method which must be called repeatedly
/// until `Poll::Ready(None)` signals EOF or an error is encountered.
#[derive(Debug)]
pub struct Connection<T> {
    /// The current lifecycle state of the connection.
    inner: ConnectionState<T>,
}
82
83impl<T: AsyncRead + AsyncWrite + Unpin> Connection<T> {
84    pub fn new(socket: T, cfg: Config, mode: Mode) -> Self {
85        Self {
86            inner: ConnectionState::Active(Active::new(socket, cfg, mode)),
87        }
88    }
89
90    /// Poll for a new outbound stream.
91    ///
92    /// This function will fail if the current state does not allow opening new outbound streams.
93    pub fn poll_new_outbound(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
94        loop {
95            match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) {
96                ConnectionState::Active(mut active) => match active.poll_new_outbound(cx) {
97                    Poll::Ready(Ok(stream)) => {
98                        self.inner = ConnectionState::Active(active);
99                        return Poll::Ready(Ok(stream));
100                    }
101                    Poll::Pending => {
102                        self.inner = ConnectionState::Active(active);
103                        return Poll::Pending;
104                    }
105                    Poll::Ready(Err(e)) => {
106                        self.inner = ConnectionState::Cleanup(active.cleanup(e));
107                        continue;
108                    }
109                },
110                ConnectionState::Closing(mut inner) => match inner.poll_unpin(cx) {
111                    Poll::Ready(Ok(())) => {
112                        self.inner = ConnectionState::Closed;
113                        return Poll::Ready(Err(ConnectionError::Closed));
114                    }
115                    Poll::Ready(Err(e)) => {
116                        self.inner = ConnectionState::Closed;
117                        return Poll::Ready(Err(e));
118                    }
119                    Poll::Pending => {
120                        self.inner = ConnectionState::Closing(inner);
121                        return Poll::Pending;
122                    }
123                },
124                ConnectionState::Cleanup(mut inner) => match inner.poll_unpin(cx) {
125                    Poll::Ready(e) => {
126                        self.inner = ConnectionState::Closed;
127                        return Poll::Ready(Err(e));
128                    }
129                    Poll::Pending => {
130                        self.inner = ConnectionState::Cleanup(inner);
131                        return Poll::Pending;
132                    }
133                },
134                ConnectionState::Closed => {
135                    self.inner = ConnectionState::Closed;
136                    return Poll::Ready(Err(ConnectionError::Closed));
137                }
138                ConnectionState::Poisoned => unreachable!(),
139            }
140        }
141    }
142
143    /// Poll for the next inbound stream.
144    ///
145    /// If this function returns `None`, the underlying connection is closed.
146    pub fn poll_next_inbound(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Stream>>> {
147        loop {
148            match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) {
149                ConnectionState::Active(mut active) => match active.poll(cx) {
150                    Poll::Ready(Ok(stream)) => {
151                        self.inner = ConnectionState::Active(active);
152                        return Poll::Ready(Some(Ok(stream)));
153                    }
154                    Poll::Ready(Err(e)) => {
155                        self.inner = ConnectionState::Cleanup(active.cleanup(e));
156                        continue;
157                    }
158                    Poll::Pending => {
159                        self.inner = ConnectionState::Active(active);
160                        return Poll::Pending;
161                    }
162                },
163                ConnectionState::Closing(mut closing) => match closing.poll_unpin(cx) {
164                    Poll::Ready(Ok(())) => {
165                        self.inner = ConnectionState::Closed;
166                        return Poll::Ready(None);
167                    }
168                    Poll::Ready(Err(e)) => {
169                        self.inner = ConnectionState::Closed;
170                        return Poll::Ready(Some(Err(e)));
171                    }
172                    Poll::Pending => {
173                        self.inner = ConnectionState::Closing(closing);
174                        return Poll::Pending;
175                    }
176                },
177                ConnectionState::Cleanup(mut cleanup) => match cleanup.poll_unpin(cx) {
178                    Poll::Ready(ConnectionError::Closed) => {
179                        self.inner = ConnectionState::Closed;
180                        return Poll::Ready(None);
181                    }
182                    Poll::Ready(other) => {
183                        self.inner = ConnectionState::Closed;
184                        return Poll::Ready(Some(Err(other)));
185                    }
186                    Poll::Pending => {
187                        self.inner = ConnectionState::Cleanup(cleanup);
188                        return Poll::Pending;
189                    }
190                },
191                ConnectionState::Closed => {
192                    self.inner = ConnectionState::Closed;
193                    return Poll::Ready(None);
194                }
195                ConnectionState::Poisoned => unreachable!(),
196            }
197        }
198    }
199
200    /// Close the connection.
201    pub fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
202        loop {
203            match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) {
204                ConnectionState::Active(active) => {
205                    self.inner = ConnectionState::Closing(active.close());
206                }
207                ConnectionState::Closing(mut inner) => match inner.poll_unpin(cx)? {
208                    Poll::Ready(()) => {
209                        self.inner = ConnectionState::Closed;
210                    }
211                    Poll::Pending => {
212                        self.inner = ConnectionState::Closing(inner);
213                        return Poll::Pending;
214                    }
215                },
216                ConnectionState::Cleanup(mut cleanup) => match cleanup.poll_unpin(cx) {
217                    Poll::Ready(reason) => {
218                        log::warn!("Failure while closing connection: {}", reason);
219                        self.inner = ConnectionState::Closed;
220                        return Poll::Ready(Ok(()));
221                    }
222                    Poll::Pending => {
223                        self.inner = ConnectionState::Cleanup(cleanup);
224                        return Poll::Pending;
225                    }
226                },
227                ConnectionState::Closed => {
228                    self.inner = ConnectionState::Closed;
229                    return Poll::Ready(Ok(()));
230                }
231                ConnectionState::Poisoned => {
232                    unreachable!()
233                }
234            }
235        }
236    }
237}
238
239impl<T> Drop for Connection<T> {
240    fn drop(&mut self) {
241        match &mut self.inner {
242            ConnectionState::Active(active) => active.drop_all_streams(),
243            ConnectionState::Closing(_) => {}
244            ConnectionState::Cleanup(_) => {}
245            ConnectionState::Closed => {}
246            ConnectionState::Poisoned => {}
247        }
248    }
249}
250
/// The lifecycle states of a [`Connection`].
enum ConnectionState<T> {
    /// The connection is alive and healthy.
    Active(Active<T>),
    /// Our user requested to shutdown the connection, we are working on it.
    Closing(Closing<T>),
    /// An error occurred and we are cleaning up our resources.
    Cleanup(Cleanup),
    /// The connection is closed.
    Closed,
    /// Placeholder that is swapped in while a state transition is in progress.
    ///
    /// Observing this state means something went wrong during a transition.
    /// Should never happen unless there is a bug.
    Poisoned,
}
263
264impl<T> fmt::Debug for ConnectionState<T> {
265    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
266        match self {
267            ConnectionState::Active(_) => write!(f, "Active"),
268            ConnectionState::Closing(_) => write!(f, "Closing"),
269            ConnectionState::Cleanup(_) => write!(f, "Cleanup"),
270            ConnectionState::Closed => write!(f, "Closed"),
271            ConnectionState::Poisoned => write!(f, "Poisoned"),
272        }
273    }
274}
275
/// The active state of [`Connection`].
struct Active<T> {
    /// Identifier of this connection, used as a prefix in log messages.
    id: Id,
    /// Whether this end is the client or the server.
    mode: Mode,
    /// The configuration this connection was created with.
    config: Arc<Config>,
    /// The fused frame I/O wrapping the underlying socket.
    socket: Fuse<frame::Io<T>>,
    /// Stream ID counter; starts at 1 for clients and at 2 for servers.
    // NOTE(review): presumably the ID handed out to the next outbound stream — confirm in `next_stream_id`.
    next_id: u32,

    /// Shared state of all streams currently known to the connection, by stream ID.
    streams: IntMap<StreamId, Arc<Mutex<stream::Shared>>>,
    /// Command receivers of all streams, each tagged with its stream ID.
    stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
    /// Waker stored by `poll` when `stream_receivers` yields `None`.
    no_streams_waker: Option<Waker>,

    /// Frames queued for sending; `poll` writes them to the socket in order.
    pending_frames: VecDeque<Frame<()>>,
    /// Waker stored by `poll_new_outbound` while too many streams await an ACK;
    /// woken when a frame carrying the ACK flag is received.
    new_outbound_stream_waker: Option<Waker>,
}
291
/// `Stream` to `Connection` commands.
///
/// The connection receives these through the per-stream command receivers.
#[derive(Debug)]
pub(crate) enum StreamCommand {
    /// A new frame (data or window update) should be sent to the remote.
    SendFrame(Frame<Either<Data, WindowUpdate>>),
    /// Close a stream; `ack` is passed on to the close frame that gets sent.
    CloseStream { ack: bool },
}
300
/// Possible actions as a result of incoming frame handling.
///
/// Every frame carried by a variant is queued for sending by `on_frame`.
#[derive(Debug)]
enum Action {
    /// Nothing to be done.
    None,
    /// A new stream has been opened by the remote, optionally together with
    /// a window update that should be sent in response.
    New(Stream, Option<Frame<WindowUpdate>>),
    /// A window update should be sent to the remote.
    Update(Frame<WindowUpdate>),
    /// A ping should be answered.
    Ping(Frame<Ping>),
    /// A stream should be reset.
    Reset(Frame<Data>),
    /// The connection should be terminated.
    Terminate(Frame<GoAway>),
}
317
318impl<T> fmt::Debug for Active<T> {
319    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
320        f.debug_struct("Connection")
321            .field("id", &self.id)
322            .field("mode", &self.mode)
323            .field("streams", &self.streams.len())
324            .field("next_id", &self.next_id)
325            .finish()
326    }
327}
328
329impl<T> fmt::Display for Active<T> {
330    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
331        write!(
332            f,
333            "(Connection {} {:?} (streams {}))",
334            self.id,
335            self.mode,
336            self.streams.len()
337        )
338    }
339}
340
341impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
342    /// Create a new `Connection` from the given I/O resource.
343    fn new(socket: T, cfg: Config, mode: Mode) -> Self {
344        let id = Id::random();
345        log::debug!("new connection: {} ({:?})", id, mode);
346        let socket = frame::Io::new(id, socket, cfg.max_buffer_size).fuse();
347        Active {
348            id,
349            mode,
350            config: Arc::new(cfg),
351            socket,
352            streams: IntMap::default(),
353            stream_receivers: SelectAll::default(),
354            no_streams_waker: None,
355            next_id: match mode {
356                Mode::Client => 1,
357                Mode::Server => 2,
358            },
359            pending_frames: VecDeque::default(),
360            new_outbound_stream_waker: None,
361        }
362    }
363
364    /// Gracefully close the connection to the remote.
365    fn close(self) -> Closing<T> {
366        Closing::new(self.stream_receivers, self.pending_frames, self.socket)
367    }
368
369    /// Cleanup all our resources.
370    ///
371    /// This should be called in the context of an unrecoverable error on the connection.
372    fn cleanup(mut self, error: ConnectionError) -> Cleanup {
373        self.drop_all_streams();
374
375        Cleanup::new(self.stream_receivers, error)
376    }
377
    /// Drive the active connection until a new inbound stream is available.
    ///
    /// Each loop iteration, in this order:
    /// 1. sends one queued frame if the socket is ready and restarts,
    /// 2. flushes the socket,
    /// 3. handles one command from the stream receivers and restarts,
    /// 4. reads one frame from the socket and handles it.
    ///
    /// Returns `Poll::Ready(Ok(stream))` when the remote opened a new stream,
    /// `Poll::Ready(Err(_))` on any socket or frame error (a socket that has
    /// ended is reported as `ConnectionError::Closed`), and `Poll::Pending`
    /// when none of the steps can make progress.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
        loop {
            // Write out queued frames first, one per iteration.
            if self.socket.poll_ready_unpin(cx).is_ready() {
                if let Some(frame) = self.pending_frames.pop_front() {
                    self.socket.start_send_unpin(frame)?;
                    continue;
                }
            }

            // Errors abort via `?`; both a completed and a pending flush fall
            // through so the remaining steps still get polled.
            match self.socket.poll_flush_unpin(cx)? {
                Poll::Ready(()) => {}
                Poll::Pending => {}
            }

            // Commands sent by streams; `(id, None)` marks the end of a
            // stream's command channel and is handled as a dropped stream.
            match self.stream_receivers.poll_next_unpin(cx) {
                Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
                    self.on_send_frame(frame);
                    continue;
                }
                Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
                    self.on_close_stream(id, ack);
                    continue;
                }
                Poll::Ready(Some((id, None))) => {
                    self.on_drop_stream(id);
                    continue;
                }
                Poll::Ready(None) => {
                    // No receivers are left to poll; remember the current task.
                    // NOTE(review): presumably woken when a new stream is added — confirm.
                    self.no_streams_waker = Some(cx.waker().clone());
                }
                Poll::Pending => {}
            }

            // Frames sent by the remote.
            match self.socket.poll_next_unpin(cx) {
                Poll::Ready(Some(frame)) => {
                    if let Some(stream) = self.on_frame(frame?)? {
                        return Poll::Ready(Ok(stream));
                    }
                    continue;
                }
                Poll::Ready(None) => {
                    return Poll::Ready(Err(ConnectionError::Closed));
                }
                Poll::Pending => {}
            }

            // If we make it this far, at least one of the above must have registered a waker.
            return Poll::Pending;
        }
    }
428
    /// Try to open a new outbound stream.
    ///
    /// Returns `ConnectionError::TooManyStreams` if the configured maximum
    /// number of streams is reached, and `Poll::Pending` (after storing the
    /// task's waker) while `MAX_ACK_BACKLOG` or more streams are waiting for
    /// an ACK. Otherwise a new stream is created, registered in `streams` and
    /// returned.
    fn poll_new_outbound(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
        if self.streams.len() >= self.config.max_num_streams {
            log::error!("{}: maximum number of streams reached", self.id);
            return Poll::Ready(Err(ConnectionError::TooManyStreams));
        }

        if self.ack_backlog() >= MAX_ACK_BACKLOG {
            log::debug!("{MAX_ACK_BACKLOG} streams waiting for ACK, registering task for wake-up until remote acknowledges at least one stream");
            self.new_outbound_stream_waker = Some(cx.waker().clone());
            return Poll::Pending;
        }

        log::trace!("{}: creating new outbound stream", self.id);

        let id = self.next_stream_id()?;
        // Credit to grant on top of the default.
        // NOTE(review): assumes `receive_window >= DEFAULT_CREDIT`; presumably enforced by `Config` — confirm.
        let extra_credit = self.config.receive_window - DEFAULT_CREDIT;

        if extra_credit > 0 {
            // Announce the larger window right away; this frame carries the SYN flag.
            let mut frame = Frame::window_update(id, extra_credit);
            frame.header_mut().syn();
            log::trace!("{}/{}: sending initial {}", self.id, id, frame.header());
            self.pending_frames.push_back(frame.into());
        }

        let mut stream = self.make_new_outbound_stream(id, self.config.receive_window);

        if extra_credit == 0 {
            // No window update was queued above, so flag the stream itself with SYN.
            stream.set_flag(stream::Flag::Syn)
        }

        log::debug!("{}: new outbound {} of {}", self.id, stream, self);
        self.streams.insert(id, stream.clone_shared());

        Poll::Ready(Ok(stream))
    }
464
465    fn on_send_frame(&mut self, frame: Frame<Either<Data, WindowUpdate>>) {
466        log::trace!(
467            "{}/{}: sending: {}",
468            self.id,
469            frame.header().stream_id(),
470            frame.header()
471        );
472        self.pending_frames.push_back(frame.into());
473    }
474
475    fn on_close_stream(&mut self, id: StreamId, ack: bool) {
476        log::trace!("{}/{}: sending close", self.id, id);
477        self.pending_frames
478            .push_back(Frame::close_stream(id, ack).into());
479    }
480
    /// Handle a stream whose command channel has ended.
    ///
    /// Removes the stream from `streams`, marks it as closed, wakes any stored
    /// reader and writer wakers, and, depending on the state the stream was in,
    /// queues a final RST or FIN frame for the remote.
    ///
    /// # Panics
    ///
    /// Panics if `stream_id` is not present in `streams`.
    fn on_drop_stream(&mut self, stream_id: StreamId) {
        let s = self.streams.remove(&stream_id).expect("stream not found");

        log::trace!("{}: removing dropped stream {}", self.id, stream_id);
        let frame = {
            let mut shared = s.lock();
            // The match below inspects the value returned by `update_state`.
            // NOTE(review): presumably the state before the update — confirm in `stream::Shared`.
            let frame = match shared.update_state(self.id, stream_id, State::Closed) {
                // The stream was dropped without calling `poll_close`.
                // We reset the stream to inform the remote of the closure.
                State::Open { .. } => {
                    let mut header = Header::data(stream_id, 0);
                    header.rst();
                    Some(Frame::new(header))
                }
                // The stream was dropped without calling `poll_close`.
                // We have already received a FIN from remote and send one
                // back which closes the stream for good.
                State::RecvClosed => {
                    let mut header = Header::data(stream_id, 0);
                    header.fin();
                    Some(Frame::new(header))
                }
                // The stream was properly closed. We already sent our FIN frame.
                // The remote may be out of credit though and blocked on
                // writing more data. We may need to reset the stream.
                State::SendClosed => {
                    if self.config.window_update_mode == WindowUpdateMode::OnRead
                        && shared.window == 0
                    {
                        // The remote may be waiting for a window update
                        // which we will never send, so reset the stream now.
                        let mut header = Header::data(stream_id, 0);
                        header.rst();
                        Some(Frame::new(header))
                    } else {
                        // The remote has either still credit or will be given more
                        // (due to an enqueued window update or because the update
                        // mode is `OnReceive`) or we already have inbound frames in
                        // the socket buffer which will be processed later. In any
                        // case we will reply with an RST in `Connection::on_data`
                        // because the stream will no longer be known.
                        None
                    }
                }
                // The stream was properly closed. We already have sent our FIN frame. The
                // remote end has already done so in the past.
                State::Closed => None,
            };
            // Wake any task still waiting to read from or write to the stream.
            if let Some(w) = shared.reader.take() {
                w.wake()
            }
            if let Some(w) = shared.writer.take() {
                w.wake()
            }
            frame
        };
        // The lock is released before the frame is queued.
        if let Some(f) = frame {
            log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
            self.pending_frames.push_back(f.into());
        }
    }
542
543    /// Process the result of reading from the socket.
544    ///
545    /// Unless `frame` is `Ok(Some(_))` we will assume the connection got closed
546    /// and return a corresponding error, which terminates the connection.
547    /// Otherwise we process the frame and potentially return a new `Stream`
548    /// if one was opened by the remote.
549    fn on_frame(&mut self, frame: Frame<()>) -> Result<Option<Stream>> {
550        log::trace!("{}: received: {}", self.id, frame.header());
551
552        if frame.header().flags().contains(header::ACK) {
553            let id = frame.header().stream_id();
554            if let Some(stream) = self.streams.get(&id) {
555                stream
556                    .lock()
557                    .update_state(self.id, id, State::Open { acknowledged: true });
558            }
559            if let Some(waker) = self.new_outbound_stream_waker.take() {
560                waker.wake();
561            }
562        }
563
564        let action = match frame.header().tag() {
565            Tag::Data => self.on_data(frame.into_data()),
566            Tag::WindowUpdate => self.on_window_update(&frame.into_window_update()),
567            Tag::Ping => self.on_ping(&frame.into_ping()),
568            Tag::GoAway => return Err(ConnectionError::Closed),
569        };
570        match action {
571            Action::None => {}
572            Action::New(stream, update) => {
573                log::trace!("{}: new inbound {} of {}", self.id, stream, self);
574                if let Some(f) = update {
575                    log::trace!("{}/{}: sending update", self.id, f.header().stream_id());
576                    self.pending_frames.push_back(f.into());
577                }
578                return Ok(Some(stream));
579            }
580            Action::Update(f) => {
581                log::trace!("{}: sending update: {:?}", self.id, f.header());
582                self.pending_frames.push_back(f.into());
583            }
584            Action::Ping(f) => {
585                log::trace!("{}/{}: pong", self.id, f.header().stream_id());
586                self.pending_frames.push_back(f.into());
587            }
588            Action::Reset(f) => {
589                log::trace!("{}/{}: sending reset", self.id, f.header().stream_id());
590                self.pending_frames.push_back(f.into());
591            }
592            Action::Terminate(f) => {
593                log::trace!("{}: sending term", self.id);
594                self.pending_frames.push_back(f.into());
595            }
596        }
597
598        Ok(None)
599    }
600
    /// Handle a data frame received from the remote.
    ///
    /// * With the RST flag, the stream (if known) is marked closed and its
    ///   reader and writer are woken.
    /// * With the SYN flag, the frame is validated and a new inbound stream is
    ///   created that already holds the frame's body.
    /// * Otherwise the body is appended to the buffer of the existing stream
    ///   and its reader is woken; frames for unknown streams are ignored.
    ///
    /// Returns the [`Action`] the caller should take, e.g. sending a window
    /// update, resetting the stream or terminating the connection.
    fn on_data(&mut self, frame: Frame<Data>) -> Action {
        let stream_id = frame.header().stream_id();

        if frame.header().flags().contains(header::RST) {
            // stream reset
            if let Some(s) = self.streams.get_mut(&stream_id) {
                let mut shared = s.lock();
                shared.update_state(self.id, stream_id, State::Closed);
                if let Some(w) = shared.reader.take() {
                    w.wake()
                }
                if let Some(w) = shared.writer.take() {
                    w.wake()
                }
            }
            return Action::None;
        }

        let is_finish = frame.header().flags().contains(header::FIN); // half-close

        if frame.header().flags().contains(header::SYN) {
            // new stream
            if !self.is_valid_remote_id(stream_id, Tag::Data) {
                log::error!("{}: invalid stream id {}", self.id, stream_id);
                return Action::Terminate(Frame::protocol_error());
            }
            // A new stream starts with the default credit, so its first body
            // must not be larger than that.
            if frame.body().len() > DEFAULT_CREDIT as usize {
                log::error!(
                    "{}/{}: 1st body of stream exceeds default credit",
                    self.id,
                    stream_id
                );
                return Action::Terminate(Frame::protocol_error());
            }
            if self.streams.contains_key(&stream_id) {
                log::error!("{}/{}: stream already exists", self.id, stream_id);
                return Action::Terminate(Frame::protocol_error());
            }
            if self.streams.len() == self.config.max_num_streams {
                log::error!("{}: maximum number of streams reached", self.id);
                return Action::Terminate(Frame::internal_error());
            }
            let mut stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT);
            let mut window_update = None;
            {
                let mut shared = stream.shared();
                if is_finish {
                    shared.update_state(self.id, stream_id, State::RecvClosed);
                }
                // Account for the received bytes and store them for the reader.
                shared.window = shared.window.saturating_sub(frame.body_len());
                shared.buffer.push(frame.into_body());

                #[allow(deprecated)]
                if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) {
                    if let Some(credit) = shared.next_window_update() {
                        shared.window += credit;
                        // The window update doubles as the acknowledgement of the new stream.
                        let mut frame = Frame::window_update(stream_id, credit);
                        frame.header_mut().ack();
                        window_update = Some(frame)
                    }
                }
            }
            if window_update.is_none() {
                // No frame carrying the ACK was prepared above, so set the
                // ACK flag on the stream itself.
                stream.set_flag(stream::Flag::Ack)
            }
            self.streams.insert(stream_id, stream.clone_shared());
            return Action::New(stream, window_update);
        }

        if let Some(s) = self.streams.get_mut(&stream_id) {
            let mut shared = s.lock();
            // The remote must not send more than the window we granted.
            if frame.body().len() > shared.window as usize {
                log::error!(
                    "{}/{}: frame body larger than window of stream",
                    self.id,
                    stream_id
                );
                return Action::Terminate(Frame::protocol_error());
            }
            if is_finish {
                shared.update_state(self.id, stream_id, State::RecvClosed);
            }
            let max_buffer_size = self.config.max_buffer_size;
            if shared.buffer.len() >= max_buffer_size {
                log::error!(
                    "{}/{}: buffer of stream grows beyond limit",
                    self.id,
                    stream_id
                );
                // The frame's body is discarded and the stream is reset.
                let mut header = Header::data(stream_id, 0);
                header.rst();
                return Action::Reset(Frame::new(header));
            }
            // Account for the received bytes, store them and wake the reader.
            shared.window = shared.window.saturating_sub(frame.body_len());
            shared.buffer.push(frame.into_body());
            if let Some(w) = shared.reader.take() {
                w.wake()
            }
            #[allow(deprecated)]
            if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) {
                if let Some(credit) = shared.next_window_update() {
                    shared.window += credit;
                    let frame = Frame::window_update(stream_id, credit);
                    return Action::Update(frame);
                }
            }
        } else {
            log::trace!(
                "{}/{}: data frame for unknown stream, possibly dropped earlier: {:?}",
                self.id,
                stream_id,
                frame
            );
            // We do not consider this a protocol violation and thus do not send a stream reset
            // because we may still be processing pending `StreamCommand`s of this stream that were
            // sent before it has been dropped and "garbage collected". Such a stream reset would
            // interfere with the frames that still need to be sent, causing premature stream
            // termination for the remote.
            //
            // See https://github.com/paritytech/yamux/issues/110 for details.
        }

        Action::None
    }
725
726    fn on_window_update(&mut self, frame: &Frame<WindowUpdate>) -> Action {
727        let stream_id = frame.header().stream_id();
728
729        if frame.header().flags().contains(header::RST) {
730            // stream reset
731            if let Some(s) = self.streams.get_mut(&stream_id) {
732                let mut shared = s.lock();
733                shared.update_state(self.id, stream_id, State::Closed);
734                if let Some(w) = shared.reader.take() {
735                    w.wake()
736                }
737                if let Some(w) = shared.writer.take() {
738                    w.wake()
739                }
740            }
741            return Action::None;
742        }
743
744        let is_finish = frame.header().flags().contains(header::FIN); // half-close
745
746        if frame.header().flags().contains(header::SYN) {
747            // new stream
748            if !self.is_valid_remote_id(stream_id, Tag::WindowUpdate) {
749                log::error!("{}: invalid stream id {}", self.id, stream_id);
750                return Action::Terminate(Frame::protocol_error());
751            }
752            if self.streams.contains_key(&stream_id) {
753                log::error!("{}/{}: stream already exists", self.id, stream_id);
754                return Action::Terminate(Frame::protocol_error());
755            }
756            if self.streams.len() == self.config.max_num_streams {
757                log::error!("{}: maximum number of streams reached", self.id);
758                return Action::Terminate(Frame::protocol_error());
759            }
760
761            let credit = frame.header().credit() + DEFAULT_CREDIT;
762            let mut stream = self.make_new_inbound_stream(stream_id, credit);
763            stream.set_flag(stream::Flag::Ack);
764
765            if is_finish {
766                stream
767                    .shared()
768                    .update_state(self.id, stream_id, State::RecvClosed);
769            }
770            self.streams.insert(stream_id, stream.clone_shared());
771            return Action::New(stream, None);
772        }
773
774        if let Some(s) = self.streams.get_mut(&stream_id) {
775            let mut shared = s.lock();
776            shared.credit += frame.header().credit();
777            if is_finish {
778                shared.update_state(self.id, stream_id, State::RecvClosed);
779            }
780            if let Some(w) = shared.writer.take() {
781                w.wake()
782            }
783        } else {
784            log::trace!(
785                "{}/{}: window update for unknown stream, possibly dropped earlier: {:?}",
786                self.id,
787                stream_id,
788                frame
789            );
790            // We do not consider this a protocol violation and thus do not send a stream reset
791            // because we may still be processing pending `StreamCommand`s of this stream that were
792            // sent before it has been dropped and "garbage collected". Such a stream reset would
793            // interfere with the frames that still need to be sent, causing premature stream
794            // termination for the remote.
795            //
796            // See https://github.com/paritytech/yamux/issues/110 for details.
797        }
798
799        Action::None
800    }
801
802    fn on_ping(&mut self, frame: &Frame<Ping>) -> Action {
803        let stream_id = frame.header().stream_id();
804        if frame.header().flags().contains(header::ACK) {
805            // pong
806            return Action::None;
807        }
808        if stream_id == CONNECTION_ID || self.streams.contains_key(&stream_id) {
809            let mut hdr = Header::ping(frame.header().nonce());
810            hdr.ack();
811            return Action::Ping(Frame::new(hdr));
812        }
813        log::trace!(
814            "{}/{}: ping for unknown stream, possibly dropped earlier: {:?}",
815            self.id,
816            stream_id,
817            frame
818        );
819        // We do not consider this a protocol violation and thus do not send a stream reset because
820        // we may still be processing pending `StreamCommand`s of this stream that were sent before
821        // it has been dropped and "garbage collected". Such a stream reset would interfere with the
822        // frames that still need to be sent, causing premature stream termination for the remote.
823        //
824        // See https://github.com/paritytech/yamux/issues/110 for details.
825
826        Action::None
827    }
828
829    fn make_new_inbound_stream(&mut self, id: StreamId, credit: u32) -> Stream {
830        let config = self.config.clone();
831
832        let (sender, receiver) = mpsc::channel(10); // 10 is an arbitrary number.
833        self.stream_receivers.push(TaggedStream::new(id, receiver));
834        if let Some(waker) = self.no_streams_waker.take() {
835            waker.wake();
836        }
837
838        Stream::new_inbound(id, self.id, config, credit, sender)
839    }
840
841    fn make_new_outbound_stream(&mut self, id: StreamId, window: u32) -> Stream {
842        let config = self.config.clone();
843
844        let (sender, receiver) = mpsc::channel(10); // 10 is an arbitrary number.
845        self.stream_receivers.push(TaggedStream::new(id, receiver));
846        if let Some(waker) = self.no_streams_waker.take() {
847            waker.wake();
848        }
849
850        Stream::new_outbound(id, self.id, config, window, sender)
851    }
852
853    fn next_stream_id(&mut self) -> Result<StreamId> {
854        let proposed = StreamId::new(self.next_id);
855        self.next_id = self
856            .next_id
857            .checked_add(2)
858            .ok_or(ConnectionError::NoMoreStreamIds)?;
859        match self.mode {
860            Mode::Client => assert!(proposed.is_client()),
861            Mode::Server => assert!(proposed.is_server()),
862        }
863        Ok(proposed)
864    }
865
866    /// The ACK backlog is defined as the number of outbound streams that have not yet been acknowledged.
867    fn ack_backlog(&mut self) -> usize {
868        self.streams
869            .iter()
870            // Whether this is an outbound stream.
871            //
872            // Clients use odd IDs and servers use even IDs.
873            // A stream is outbound if:
874            //
875            // - Its ID is odd and we are the client.
876            // - Its ID is even and we are the server.
877            .filter(|(id, _)| match self.mode {
878                Mode::Client => id.is_client(),
879                Mode::Server => id.is_server(),
880            })
881            .filter(|(_, s)| s.lock().is_pending_ack())
882            .count()
883    }
884
885    // Check if the given stream ID is valid w.r.t. the provided tag and our connection mode.
886    fn is_valid_remote_id(&self, id: StreamId, tag: Tag) -> bool {
887        if tag == Tag::Ping || tag == Tag::GoAway {
888            return id.is_session();
889        }
890        match self.mode {
891            Mode::Client => id.is_server(),
892            Mode::Server => id.is_client(),
893        }
894    }
895}
896
897impl<T> Active<T> {
898    /// Close and drop all `Stream`s and wake any pending `Waker`s.
899    fn drop_all_streams(&mut self) {
900        for (id, s) in self.streams.drain() {
901            let mut shared = s.lock();
902            shared.update_state(self.id, id, State::Closed);
903            if let Some(w) = shared.reader.take() {
904                w.wake()
905            }
906            if let Some(w) = shared.writer.take() {
907                w.wake()
908            }
909        }
910    }
911}