1use crate::connection::rtt::Rtt;
12use crate::frame::header::ACK;
13use crate::ConnectionError;
14use crate::{
15 chunks::Chunks,
16 connection::{self, rtt, StreamCommand},
17 frame::{
18 header::{Data, Header, StreamId, WindowUpdate},
19 Frame,
20 },
21 Config, DEFAULT_CREDIT,
22};
23use flow_control::FlowController;
24use futures::{
25 channel::mpsc,
26 future::Either,
27 io::{AsyncRead, AsyncWrite},
28 ready, SinkExt,
29};
30use parking_lot::{Mutex, MutexGuard};
31use std::{
32 fmt, io,
33 pin::Pin,
34 sync::Arc,
35 task::{Context, Poll, Waker},
36};
37
38mod flow_control;
39
/// Lifecycle state of a stream, tracking which directions are still usable.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum State {
    /// Both the sending and the receiving half are open.
    Open {
        /// Starts out `false`. In this file it is set to `true` when we send a
        /// data frame carrying the ACK flag. Other transitions (presumably on
        /// receiving an ACK from the peer) happen elsewhere — confirm against
        /// the connection code.
        acknowledged: bool,
    },
    /// Our sending half is closed; reading is still possible.
    SendClosed,
    /// Our receiving half is closed; writing is still possible.
    RecvClosed,
    /// Both halves are closed.
    Closed,
}

impl State {
    /// Returns `true` unless the receiving half is closed
    /// (`RecvClosed` or `Closed`).
    pub fn can_read(self) -> bool {
        match self {
            State::Open { .. } | State::SendClosed => true,
            State::RecvClosed | State::Closed => false,
        }
    }

    /// Returns `true` unless the sending half is closed
    /// (`SendClosed` or `Closed`).
    pub fn can_write(self) -> bool {
        match self {
            State::Open { .. } | State::RecvClosed => true,
            State::SendClosed | State::Closed => false,
        }
    }
}
75
/// Flag to attach to the next frame a stream sends.
///
/// It is consumed (reset to `None`) by the first frame that carries it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum Flag {
    /// No flag pending.
    None,
    /// Set for locally opened (outbound) streams: the next frame carries SYN.
    Syn,
    /// Set for remotely opened (inbound) streams: the next frame carries ACK.
    Ack,
}
86
87pub struct Stream {
95 id: StreamId,
96 conn: connection::Id,
97 config: Arc<Config>,
98 sender: mpsc::Sender<StreamCommand>,
99 flag: Flag,
100 shared: Arc<Mutex<Shared>>,
101}
102
103impl fmt::Debug for Stream {
104 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
105 f.debug_struct("Stream")
106 .field("id", &self.id.val())
107 .field("connection", &self.conn)
108 .finish()
109 }
110}
111
112impl fmt::Display for Stream {
113 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
114 write!(f, "(Stream {}/{})", self.conn, self.id.val())
115 }
116}
117
118impl Stream {
119 pub(crate) fn new_inbound(
120 id: StreamId,
121 conn: connection::Id,
122 config: Arc<Config>,
123 send_window: u32,
124 sender: mpsc::Sender<StreamCommand>,
125 rtt: rtt::Rtt,
126 accumulated_max_stream_windows: Arc<Mutex<usize>>,
127 ) -> Self {
128 Self {
129 id,
130 conn,
131 config: config.clone(),
132 sender,
133 flag: Flag::Ack,
134 shared: Arc::new(Mutex::new(Shared::new(
135 DEFAULT_CREDIT,
136 send_window,
137 accumulated_max_stream_windows,
138 rtt,
139 config,
140 ))),
141 }
142 }
143
144 pub(crate) fn new_outbound(
145 id: StreamId,
146 conn: connection::Id,
147 config: Arc<Config>,
148 sender: mpsc::Sender<StreamCommand>,
149 rtt: rtt::Rtt,
150 accumulated_max_stream_windows: Arc<Mutex<usize>>,
151 ) -> Self {
152 Self {
153 id,
154 conn,
155 config: config.clone(),
156 sender,
157 flag: Flag::Syn,
158 shared: Arc::new(Mutex::new(Shared::new(
159 DEFAULT_CREDIT,
160 DEFAULT_CREDIT,
161 accumulated_max_stream_windows,
162 rtt,
163 config,
164 ))),
165 }
166 }
167
168 pub fn id(&self) -> StreamId {
170 self.id
171 }
172
173 pub fn is_write_closed(&self) -> bool {
174 matches!(self.shared().state(), State::SendClosed)
175 }
176
177 pub fn is_closed(&self) -> bool {
178 matches!(self.shared().state(), State::Closed)
179 }
180
181 pub fn is_pending_ack(&self) -> bool {
183 self.shared().is_pending_ack()
184 }
185
186 pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> {
187 self.shared.lock()
188 }
189
190 pub(crate) fn clone_shared(&self) -> Arc<Mutex<Shared>> {
191 self.shared.clone()
192 }
193
194 fn write_zero_err(&self) -> io::Error {
195 let msg = format!("{}/{}: connection is closed", self.conn, self.id);
196 io::Error::new(io::ErrorKind::WriteZero, msg)
197 }
198
199 fn add_flag(&mut self, header: &mut Header<Either<Data, WindowUpdate>>) {
201 match self.flag {
202 Flag::None => (),
203 Flag::Syn => {
204 header.syn();
205 self.flag = Flag::None
206 }
207 Flag::Ack => {
208 header.ack();
209 self.flag = Flag::None
210 }
211 }
212 }
213
214 fn send_window_update(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
217 if !self.shared.lock().state.can_read() {
218 return Poll::Ready(Ok(()));
219 }
220
221 ready!(self
222 .sender
223 .poll_ready(cx)
224 .map_err(|_| self.write_zero_err())?);
225
226 let Some(credit) = self.shared.lock().next_window_update() else {
227 return Poll::Ready(Ok(()));
228 };
229
230 let mut frame = Frame::window_update(self.id, credit).right();
231 self.add_flag(frame.header_mut());
232 let cmd = StreamCommand::SendFrame(frame);
233 self.sender
234 .start_send(cmd)
235 .map_err(|_| self.write_zero_err())?;
236
237 Poll::Ready(Ok(()))
238 }
239}
240
/// An owned chunk of stream data, as yielded by the `futures::stream::Stream`
/// implementation of `Stream`.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Packet(Vec<u8>);

impl AsRef<[u8]> for Packet {
    /// Borrows the packet's bytes.
    fn as_ref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
250
251impl futures::stream::Stream for Stream {
252 type Item = io::Result<Packet>;
253
254 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
255 if !self.config.read_after_close && self.sender.is_closed() {
256 return Poll::Ready(None);
257 }
258
259 match self.send_window_update(cx) {
260 Poll::Ready(Ok(())) => {}
261 Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))),
262 Poll::Pending => {}
264 }
265
266 let mut shared = self.shared();
267
268 if let Some(bytes) = shared.buffer.pop() {
269 let off = bytes.offset();
270 let mut vec = bytes.into_vec();
271 if off != 0 {
272 log::debug!(
277 "{}/{}: chunk has been partially consumed",
278 self.conn,
279 self.id
280 );
281 vec = vec.split_off(off)
282 }
283 return Poll::Ready(Some(Ok(Packet(vec))));
284 }
285
286 if !shared.state().can_read() {
288 log::debug!("{}/{}: eof", self.conn, self.id);
289 return Poll::Ready(None); }
291
292 shared.reader = Some(cx.waker().clone());
295
296 Poll::Pending
297 }
298}
299
300impl AsyncRead for Stream {
303 fn poll_read(
304 mut self: Pin<&mut Self>,
305 cx: &mut Context,
306 buf: &mut [u8],
307 ) -> Poll<io::Result<usize>> {
308 if !self.config.read_after_close && self.sender.is_closed() {
309 return Poll::Ready(Ok(0));
310 }
311
312 match self.send_window_update(cx) {
313 Poll::Ready(Ok(())) => {}
314 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
315 Poll::Pending => {}
317 }
318
319 let mut shared = self.shared();
321 let mut n = 0;
322 while let Some(chunk) = shared.buffer.front_mut() {
323 if chunk.is_empty() {
324 shared.buffer.pop();
325 continue;
326 }
327 let k = std::cmp::min(chunk.len(), buf.len() - n);
328 buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
329 n += k;
330 chunk.advance(k);
331 if n == buf.len() {
332 break;
333 }
334 }
335
336 if n > 0 {
337 log::trace!("{}/{}: read {} bytes", self.conn, self.id, n);
338 return Poll::Ready(Ok(n));
339 }
340
341 if !shared.state().can_read() {
343 log::debug!("{}/{}: eof", self.conn, self.id);
344 return Poll::Ready(Ok(0)); }
346
347 shared.reader = Some(cx.waker().clone());
350
351 Poll::Pending
352 }
353}
354
355impl AsyncWrite for Stream {
356 fn poll_write(
357 mut self: Pin<&mut Self>,
358 cx: &mut Context,
359 buf: &[u8],
360 ) -> Poll<io::Result<usize>> {
361 ready!(self
362 .sender
363 .poll_ready(cx)
364 .map_err(|_| self.write_zero_err())?);
365 let body = {
366 let mut shared = self.shared();
367 if !shared.state().can_write() {
368 log::debug!("{}/{}: can no longer write", self.conn, self.id);
369 return Poll::Ready(Err(self.write_zero_err()));
370 }
371 if shared.send_window() == 0 {
372 log::trace!("{}/{}: no more credit left", self.conn, self.id);
373 shared.writer = Some(cx.waker().clone());
374 return Poll::Pending;
375 }
376 let k = std::cmp::min(shared.send_window() as usize, buf.len());
377 let k = std::cmp::min(k, self.config.split_send_size);
378 shared
379 .consume_send_window(k as u32)
380 .expect("not exceed receive window");
381 Vec::from(&buf[..k])
382 };
383 let n = body.len();
384 let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left();
385 self.add_flag(frame.header_mut());
386 log::trace!("{}/{}: write {} bytes", self.conn, self.id, n);
387
388 if frame.header().flags().contains(ACK) {
393 self.shared()
394 .update_state(self.conn, self.id, State::Open { acknowledged: true });
395 }
396
397 let cmd = StreamCommand::SendFrame(frame);
398 self.sender
399 .start_send(cmd)
400 .map_err(|_| self.write_zero_err())?;
401 Poll::Ready(Ok(n))
402 }
403
404 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
405 self.sender
406 .poll_flush_unpin(cx)
407 .map_err(|_| self.write_zero_err())
408 }
409
410 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
411 if self.is_closed() {
412 return Poll::Ready(Ok(()));
413 }
414 ready!(self
415 .sender
416 .poll_ready(cx)
417 .map_err(|_| self.write_zero_err())?);
418 let ack = if self.flag == Flag::Ack {
419 self.flag = Flag::None;
420 true
421 } else {
422 false
423 };
424 log::trace!("{}/{}: close", self.conn, self.id);
425 let cmd = StreamCommand::CloseStream { ack };
426 self.sender
427 .start_send(cmd)
428 .map_err(|_| self.write_zero_err())?;
429 self.shared()
430 .update_state(self.conn, self.id, State::SendClosed);
431 Poll::Ready(Ok(()))
432 }
433}
434
435#[derive(Debug)]
436pub(crate) struct Shared {
437 state: State,
438 flow_controller: FlowController,
439 pub(crate) buffer: Chunks,
440 pub(crate) reader: Option<Waker>,
441 pub(crate) writer: Option<Waker>,
442}
443
444impl Shared {
445 fn new(
446 receive_window: u32,
447 send_window: u32,
448 accumulated_max_stream_windows: Arc<Mutex<usize>>,
449 rtt: Rtt,
450 config: Arc<Config>,
451 ) -> Self {
452 Shared {
453 state: State::Open {
454 acknowledged: false,
455 },
456 flow_controller: FlowController::new(
457 receive_window,
458 send_window,
459 accumulated_max_stream_windows,
460 rtt,
461 config,
462 ),
463 buffer: Chunks::new(),
464 reader: None,
465 writer: None,
466 }
467 }
468
469 pub(crate) fn state(&self) -> State {
470 self.state
471 }
472
473 pub(crate) fn update_state(
475 &mut self,
476 cid: connection::Id,
477 sid: StreamId,
478 next: State,
479 ) -> State {
480 use self::State::*;
481
482 let current = self.state;
483
484 match (current, next) {
485 (Closed, _) => {}
486 (Open { .. }, _) => self.state = next,
487 (RecvClosed, Closed) => self.state = Closed,
488 (RecvClosed, Open { .. }) => {}
489 (RecvClosed, RecvClosed) => {}
490 (RecvClosed, SendClosed) => self.state = Closed,
491 (SendClosed, Closed) => self.state = Closed,
492 (SendClosed, Open { .. }) => {}
493 (SendClosed, RecvClosed) => self.state = Closed,
494 (SendClosed, SendClosed) => {}
495 }
496
497 log::trace!(
498 "{}/{}: update state: (from {:?} to {:?} -> {:?})",
499 cid,
500 sid,
501 current,
502 next,
503 self.state
504 );
505
506 current }
508
509 pub(crate) fn next_window_update(&mut self) -> Option<u32> {
510 self.flow_controller.next_window_update(self.buffer.len())
511 }
512
513 pub fn is_pending_ack(&self) -> bool {
515 matches!(
516 self.state(),
517 State::Open {
518 acknowledged: false
519 }
520 )
521 }
522
523 pub(crate) fn send_window(&self) -> u32 {
524 self.flow_controller.send_window()
525 }
526
527 pub(crate) fn consume_send_window(&mut self, i: u32) -> Result<(), ConnectionError> {
528 self.flow_controller.consume_send_window(i)
529 }
530
531 pub(crate) fn increase_send_window_by(&mut self, i: u32) -> Result<(), ConnectionError> {
532 self.flow_controller.increase_send_window_by(i)
533 }
534
535 pub(crate) fn consume_receive_window(&mut self, i: u32) -> Result<(), ConnectionError> {
536 self.flow_controller.consume_receive_window(i)
537 }
538}