1use crate::connection::rtt::Rtt;
12use crate::frame::header::ACK;
13use crate::{
14 chunks::Chunks,
15 connection::{self, rtt, StreamCommand},
16 frame::{
17 header::{Data, Header, StreamId, WindowUpdate},
18 Frame,
19 },
20 Config, DEFAULT_CREDIT,
21};
22use flow_control::FlowController;
23use futures::{
24 channel::mpsc,
25 future::Either,
26 io::{AsyncRead, AsyncWrite},
27 ready, SinkExt,
28};
29use parking_lot::{Mutex, MutexGuard};
30use std::{
31 fmt, io,
32 pin::Pin,
33 sync::Arc,
34 task::{Context, Poll, Waker},
35};
36
37mod flow_control;
38
/// Lifecycle state of a stream.
///
/// `SendClosed` and `RecvClosed` are the two half-closed states; `Closed`
/// means neither reading nor writing is possible any more.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum State {
    /// Both halves are open.
    Open {
        /// `false` while the stream is still waiting on its ACK handshake
        /// (see `Shared::is_pending_ack`).
        acknowledged: bool,
    },
    /// The write half is closed; reading is still possible.
    SendClosed,
    /// The read half is closed; writing is still possible.
    RecvClosed,
    /// Both halves are closed.
    Closed,
}

impl State {
    /// Returns `true` unless the read half is closed, i.e. the state is
    /// neither `RecvClosed` nor `Closed`.
    pub fn can_read(self) -> bool {
        match self {
            State::Open { .. } | State::SendClosed => true,
            State::RecvClosed | State::Closed => false,
        }
    }

    /// Returns `true` unless the write half is closed, i.e. the state is
    /// neither `SendClosed` nor `Closed`.
    pub fn can_write(self) -> bool {
        match self {
            State::Open { .. } | State::RecvClosed => true,
            State::SendClosed | State::Closed => false,
        }
    }
}
74
/// Header flag still owed to the remote on this stream's next outgoing frame.
///
/// Once a frame has carried the flag, it is reset to `Flag::None`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Flag {
    /// No flag is pending.
    None,
    /// Attach SYN; set for locally opened (outbound) streams.
    Syn,
    /// Attach ACK; set for remotely opened (inbound) streams.
    Ack,
}
85
86pub struct Stream {
94 id: StreamId,
95 conn: connection::Id,
96 config: Arc<Config>,
97 sender: mpsc::Sender<StreamCommand>,
98 flag: Flag,
99 shared: Arc<Mutex<Shared>>,
100}
101
102impl fmt::Debug for Stream {
103 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
104 f.debug_struct("Stream")
105 .field("id", &self.id.val())
106 .field("connection", &self.conn)
107 .finish()
108 }
109}
110
111impl fmt::Display for Stream {
112 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
113 write!(f, "(Stream {}/{})", self.conn, self.id.val())
114 }
115}
116
117impl Stream {
118 pub(crate) fn new_inbound(
119 id: StreamId,
120 conn: connection::Id,
121 config: Arc<Config>,
122 send_window: u32,
123 sender: mpsc::Sender<StreamCommand>,
124 rtt: rtt::Rtt,
125 accumulated_max_stream_windows: Arc<Mutex<usize>>,
126 ) -> Self {
127 Self {
128 id,
129 conn,
130 config: config.clone(),
131 sender,
132 flag: Flag::Ack,
133 shared: Arc::new(Mutex::new(Shared::new(
134 DEFAULT_CREDIT,
135 send_window,
136 accumulated_max_stream_windows,
137 rtt,
138 config,
139 ))),
140 }
141 }
142
143 pub(crate) fn new_outbound(
144 id: StreamId,
145 conn: connection::Id,
146 config: Arc<Config>,
147 sender: mpsc::Sender<StreamCommand>,
148 rtt: rtt::Rtt,
149 accumulated_max_stream_windows: Arc<Mutex<usize>>,
150 ) -> Self {
151 Self {
152 id,
153 conn,
154 config: config.clone(),
155 sender,
156 flag: Flag::Syn,
157 shared: Arc::new(Mutex::new(Shared::new(
158 DEFAULT_CREDIT,
159 DEFAULT_CREDIT,
160 accumulated_max_stream_windows,
161 rtt,
162 config,
163 ))),
164 }
165 }
166
167 pub fn id(&self) -> StreamId {
169 self.id
170 }
171
172 pub fn is_write_closed(&self) -> bool {
173 matches!(self.shared().state(), State::SendClosed)
174 }
175
176 pub fn is_closed(&self) -> bool {
177 matches!(self.shared().state(), State::Closed)
178 }
179
180 pub fn is_pending_ack(&self) -> bool {
182 self.shared().is_pending_ack()
183 }
184
185 pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> {
186 self.shared.lock()
187 }
188
189 pub(crate) fn clone_shared(&self) -> Arc<Mutex<Shared>> {
190 self.shared.clone()
191 }
192
193 fn write_zero_err(&self) -> io::Error {
194 let msg = format!("{}/{}: connection is closed", self.conn, self.id);
195 io::Error::new(io::ErrorKind::WriteZero, msg)
196 }
197
198 fn add_flag(&mut self, header: &mut Header<Either<Data, WindowUpdate>>) {
200 match self.flag {
201 Flag::None => (),
202 Flag::Syn => {
203 header.syn();
204 self.flag = Flag::None
205 }
206 Flag::Ack => {
207 header.ack();
208 self.flag = Flag::None
209 }
210 }
211 }
212
213 fn send_window_update(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
216 if !self.shared.lock().state.can_read() {
217 return Poll::Ready(Ok(()));
218 }
219
220 ready!(self
221 .sender
222 .poll_ready(cx)
223 .map_err(|_| self.write_zero_err())?);
224
225 let Some(credit) = self.shared.lock().next_window_update() else {
226 return Poll::Ready(Ok(()));
227 };
228
229 let mut frame = Frame::window_update(self.id, credit).right();
230 self.add_flag(frame.header_mut());
231 let cmd = StreamCommand::SendFrame(frame);
232 self.sender
233 .start_send(cmd)
234 .map_err(|_| self.write_zero_err())?;
235
236 Poll::Ready(Ok(()))
237 }
238}
239
/// An owned chunk of stream data, as yielded by the
/// `futures::stream::Stream` implementation of `Stream`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Packet(Vec<u8>);

impl AsRef<[u8]> for Packet {
    /// Borrows the packet's bytes.
    fn as_ref(&self) -> &[u8] {
        let Packet(bytes) = self;
        bytes.as_slice()
    }
}
249
250impl futures::stream::Stream for Stream {
251 type Item = io::Result<Packet>;
252
253 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
254 if !self.config.read_after_close && self.sender.is_closed() {
255 return Poll::Ready(None);
256 }
257
258 match self.send_window_update(cx) {
259 Poll::Ready(Ok(())) => {}
260 Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))),
261 Poll::Pending => {}
263 }
264
265 let mut shared = self.shared();
266
267 if let Some(bytes) = shared.buffer.pop() {
268 let off = bytes.offset();
269 let mut vec = bytes.into_vec();
270 if off != 0 {
271 log::debug!(
276 "{}/{}: chunk has been partially consumed",
277 self.conn,
278 self.id
279 );
280 vec = vec.split_off(off)
281 }
282 return Poll::Ready(Some(Ok(Packet(vec))));
283 }
284
285 if !shared.state().can_read() {
287 log::debug!("{}/{}: eof", self.conn, self.id);
288 return Poll::Ready(None); }
290
291 shared.reader = Some(cx.waker().clone());
294
295 Poll::Pending
296 }
297}
298
299impl AsyncRead for Stream {
302 fn poll_read(
303 mut self: Pin<&mut Self>,
304 cx: &mut Context,
305 buf: &mut [u8],
306 ) -> Poll<io::Result<usize>> {
307 if !self.config.read_after_close && self.sender.is_closed() {
308 return Poll::Ready(Ok(0));
309 }
310
311 match self.send_window_update(cx) {
312 Poll::Ready(Ok(())) => {}
313 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
314 Poll::Pending => {}
316 }
317
318 let mut shared = self.shared();
320 let mut n = 0;
321 while let Some(chunk) = shared.buffer.front_mut() {
322 if chunk.is_empty() {
323 shared.buffer.pop();
324 continue;
325 }
326 let k = std::cmp::min(chunk.len(), buf.len() - n);
327 buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
328 n += k;
329 chunk.advance(k);
330 if n == buf.len() {
331 break;
332 }
333 }
334
335 if n > 0 {
336 log::trace!("{}/{}: read {} bytes", self.conn, self.id, n);
337 return Poll::Ready(Ok(n));
338 }
339
340 if !shared.state().can_read() {
342 log::debug!("{}/{}: eof", self.conn, self.id);
343 return Poll::Ready(Ok(0)); }
345
346 shared.reader = Some(cx.waker().clone());
349
350 Poll::Pending
351 }
352}
353
354impl AsyncWrite for Stream {
355 fn poll_write(
356 mut self: Pin<&mut Self>,
357 cx: &mut Context,
358 buf: &[u8],
359 ) -> Poll<io::Result<usize>> {
360 ready!(self
361 .sender
362 .poll_ready(cx)
363 .map_err(|_| self.write_zero_err())?);
364 let body = {
365 let mut shared = self.shared();
366 if !shared.state().can_write() {
367 log::debug!("{}/{}: can no longer write", self.conn, self.id);
368 return Poll::Ready(Err(self.write_zero_err()));
369 }
370 if shared.send_window() == 0 {
371 log::trace!("{}/{}: no more credit left", self.conn, self.id);
372 shared.writer = Some(cx.waker().clone());
373 return Poll::Pending;
374 }
375 let k = std::cmp::min(
376 shared.send_window(),
377 buf.len().try_into().unwrap_or(u32::MAX),
378 );
379 let k = std::cmp::min(
380 k,
381 self.config.split_send_size.try_into().unwrap_or(u32::MAX),
382 );
383 shared.consume_send_window(k);
384 Vec::from(&buf[..k as usize])
385 };
386 let n = body.len();
387 let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left();
388 self.add_flag(frame.header_mut());
389 log::trace!("{}/{}: write {} bytes", self.conn, self.id, n);
390
391 if frame.header().flags().contains(ACK) {
396 self.shared()
397 .update_state(self.conn, self.id, State::Open { acknowledged: true });
398 }
399
400 let cmd = StreamCommand::SendFrame(frame);
401 self.sender
402 .start_send(cmd)
403 .map_err(|_| self.write_zero_err())?;
404 Poll::Ready(Ok(n))
405 }
406
407 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
408 self.sender
409 .poll_flush_unpin(cx)
410 .map_err(|_| self.write_zero_err())
411 }
412
413 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
414 if self.is_closed() {
415 return Poll::Ready(Ok(()));
416 }
417 ready!(self
418 .sender
419 .poll_ready(cx)
420 .map_err(|_| self.write_zero_err())?);
421 let ack = if self.flag == Flag::Ack {
422 self.flag = Flag::None;
423 true
424 } else {
425 false
426 };
427 log::trace!("{}/{}: close", self.conn, self.id);
428 let cmd = StreamCommand::CloseStream { ack };
429 self.sender
430 .start_send(cmd)
431 .map_err(|_| self.write_zero_err())?;
432 self.shared()
433 .update_state(self.conn, self.id, State::SendClosed);
434 Poll::Ready(Ok(()))
435 }
436}
437
438#[derive(Debug)]
439pub(crate) struct Shared {
440 state: State,
441 flow_controller: FlowController,
442 pub(crate) buffer: Chunks,
443 pub(crate) reader: Option<Waker>,
444 pub(crate) writer: Option<Waker>,
445}
446
447impl Shared {
448 fn new(
449 receive_window: u32,
450 send_window: u32,
451 accumulated_max_stream_windows: Arc<Mutex<usize>>,
452 rtt: Rtt,
453 config: Arc<Config>,
454 ) -> Self {
455 Shared {
456 state: State::Open {
457 acknowledged: false,
458 },
459 flow_controller: FlowController::new(
460 receive_window,
461 send_window,
462 accumulated_max_stream_windows,
463 rtt,
464 config,
465 ),
466 buffer: Chunks::new(),
467 reader: None,
468 writer: None,
469 }
470 }
471
472 pub(crate) fn state(&self) -> State {
473 self.state
474 }
475
476 pub(crate) fn update_state(
478 &mut self,
479 cid: connection::Id,
480 sid: StreamId,
481 next: State,
482 ) -> State {
483 use self::State::*;
484
485 let current = self.state;
486
487 match (current, next) {
488 (Closed, _) => {}
489 (Open { .. }, _) => self.state = next,
490 (RecvClosed, Closed) => self.state = Closed,
491 (RecvClosed, Open { .. }) => {}
492 (RecvClosed, RecvClosed) => {}
493 (RecvClosed, SendClosed) => self.state = Closed,
494 (SendClosed, Closed) => self.state = Closed,
495 (SendClosed, Open { .. }) => {}
496 (SendClosed, RecvClosed) => self.state = Closed,
497 (SendClosed, SendClosed) => {}
498 }
499
500 log::trace!(
501 "{}/{}: update state: (from {:?} to {:?} -> {:?})",
502 cid,
503 sid,
504 current,
505 next,
506 self.state
507 );
508
509 current }
511
512 pub(crate) fn next_window_update(&mut self) -> Option<u32> {
513 self.flow_controller.next_window_update(self.buffer.len())
514 }
515
516 pub fn is_pending_ack(&self) -> bool {
518 matches!(
519 self.state(),
520 State::Open {
521 acknowledged: false
522 }
523 )
524 }
525
526 pub(crate) fn send_window(&self) -> u32 {
527 self.flow_controller.send_window()
528 }
529
530 pub(crate) fn consume_send_window(&mut self, i: u32) {
531 self.flow_controller.consume_send_window(i)
532 }
533
534 pub(crate) fn increase_send_window_by(&mut self, i: u32) {
535 self.flow_controller.increase_send_window_by(i)
536 }
537
538 pub(crate) fn receive_window(&self) -> u32 {
539 self.flow_controller.receive_window()
540 }
541
542 pub(crate) fn consume_receive_window(&mut self, i: u32) {
543 self.flow_controller.consume_receive_window(i)
544 }
545}