1use crate::frame::header::ACK;
12use crate::{
13 chunks::Chunks,
14 connection::{self, StreamCommand},
15 frame::{
16 header::{Data, Header, StreamId, WindowUpdate},
17 Frame,
18 },
19 Config, WindowUpdateMode, DEFAULT_CREDIT,
20};
21use futures::{
22 channel::mpsc,
23 future::Either,
24 io::{AsyncRead, AsyncWrite},
25 ready, SinkExt,
26};
27use parking_lot::{Mutex, MutexGuard};
28use std::convert::TryInto;
29use std::{
30 fmt, io,
31 pin::Pin,
32 sync::Arc,
33 task::{Context, Poll, Waker},
34};
35
/// Lifecycle state of a stream, tracking each direction independently.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum State {
    /// Both directions are usable.
    Open {
        /// `false` while the stream is still pending acknowledgement
        /// (this is exactly what `Shared::is_pending_ack` reports).
        acknowledged: bool,
    },
    /// Our sending direction is closed; incoming data can still be read.
    SendClosed,
    /// The receiving direction is closed; data can still be written.
    RecvClosed,
    /// Neither direction is usable any more.
    Closed,
}
59
60impl State {
61 pub fn can_read(self) -> bool {
63 !matches!(self, State::RecvClosed | State::Closed)
64 }
65
66 pub fn can_write(self) -> bool {
68 !matches!(self, State::SendClosed | State::Closed)
69 }
70}
71
/// A one-shot flag that the stream attaches to the next frame it builds.
///
/// Once applied, the stream resets its flag back to [`Flag::None`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum Flag {
    /// Nothing to attach.
    None,
    /// Attach the SYN flag to the next frame.
    Syn,
    /// Attach the ACK flag to the next frame.
    Ack,
}
82
83pub struct Stream {
91 id: StreamId,
92 conn: connection::Id,
93 config: Arc<Config>,
94 sender: mpsc::Sender<StreamCommand>,
95 flag: Flag,
96 shared: Arc<Mutex<Shared>>,
97}
98
99impl fmt::Debug for Stream {
100 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
101 f.debug_struct("Stream")
102 .field("id", &self.id.val())
103 .field("connection", &self.conn)
104 .finish()
105 }
106}
107
108impl fmt::Display for Stream {
109 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
110 write!(f, "(Stream {}/{})", self.conn, self.id.val())
111 }
112}
113
114impl Stream {
115 pub(crate) fn new_inbound(
116 id: StreamId,
117 conn: connection::Id,
118 config: Arc<Config>,
119 credit: u32,
120 sender: mpsc::Sender<StreamCommand>,
121 ) -> Self {
122 Self {
123 id,
124 conn,
125 config: config.clone(),
126 sender,
127 flag: Flag::None,
128 shared: Arc::new(Mutex::new(Shared::new(DEFAULT_CREDIT, credit, config))),
129 }
130 }
131
132 pub(crate) fn new_outbound(
133 id: StreamId,
134 conn: connection::Id,
135 config: Arc<Config>,
136 window: u32,
137 sender: mpsc::Sender<StreamCommand>,
138 ) -> Self {
139 Self {
140 id,
141 conn,
142 config: config.clone(),
143 sender,
144 flag: Flag::None,
145 shared: Arc::new(Mutex::new(Shared::new(window, DEFAULT_CREDIT, config))),
146 }
147 }
148
149 pub fn id(&self) -> StreamId {
151 self.id
152 }
153
154 pub fn is_write_closed(&self) -> bool {
155 matches!(self.shared().state(), State::SendClosed)
156 }
157
158 pub fn is_closed(&self) -> bool {
159 matches!(self.shared().state(), State::Closed)
160 }
161
162 pub fn is_pending_ack(&self) -> bool {
164 self.shared().is_pending_ack()
165 }
166
167 pub(crate) fn set_flag(&mut self, flag: Flag) {
169 self.flag = flag
170 }
171
172 pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> {
173 self.shared.lock()
174 }
175
176 pub(crate) fn clone_shared(&self) -> Arc<Mutex<Shared>> {
177 self.shared.clone()
178 }
179
180 fn write_zero_err(&self) -> io::Error {
181 let msg = format!("{}/{}: connection is closed", self.conn, self.id);
182 io::Error::new(io::ErrorKind::WriteZero, msg)
183 }
184
185 fn add_flag(&mut self, header: &mut Header<Either<Data, WindowUpdate>>) {
187 match self.flag {
188 Flag::None => (),
189 Flag::Syn => {
190 header.syn();
191 self.flag = Flag::None
192 }
193 Flag::Ack => {
194 header.ack();
195 self.flag = Flag::None
196 }
197 }
198 }
199
200 fn send_window_update(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
203 #[allow(deprecated)]
206 if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) {
207 return Poll::Ready(Ok(()));
208 }
209
210 let mut shared = self.shared.lock();
211
212 if let Some(credit) = shared.next_window_update() {
213 ready!(self
214 .sender
215 .poll_ready(cx)
216 .map_err(|_| self.write_zero_err())?);
217
218 shared.window += credit;
219 drop(shared);
220
221 let mut frame = Frame::window_update(self.id, credit).right();
222 self.add_flag(frame.header_mut());
223 let cmd = StreamCommand::SendFrame(frame);
224 self.sender
225 .start_send(cmd)
226 .map_err(|_| self.write_zero_err())?;
227 }
228
229 Poll::Ready(Ok(()))
230 }
231}
232
/// An owned chunk of bytes received on a stream, yielded by its
/// `futures::stream::Stream` implementation.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Packet(Vec<u8>);
236
237impl AsRef<[u8]> for Packet {
238 fn as_ref(&self) -> &[u8] {
239 self.0.as_ref()
240 }
241}
242
243impl futures::stream::Stream for Stream {
244 type Item = io::Result<Packet>;
245
246 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
247 if !self.config.read_after_close && self.sender.is_closed() {
248 return Poll::Ready(None);
249 }
250
251 match self.send_window_update(cx) {
252 Poll::Ready(Ok(())) => {}
253 Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))),
254 Poll::Pending => {}
256 }
257
258 let mut shared = self.shared();
259
260 if let Some(bytes) = shared.buffer.pop() {
261 let off = bytes.offset();
262 let mut vec = bytes.into_vec();
263 if off != 0 {
264 log::debug!(
269 "{}/{}: chunk has been partially consumed",
270 self.conn,
271 self.id
272 );
273 vec = vec.split_off(off)
274 }
275 return Poll::Ready(Some(Ok(Packet(vec))));
276 }
277
278 if !shared.state().can_read() {
280 log::debug!("{}/{}: eof", self.conn, self.id);
281 return Poll::Ready(None); }
283
284 shared.reader = Some(cx.waker().clone());
287
288 Poll::Pending
289 }
290}
291
292impl AsyncRead for Stream {
295 fn poll_read(
296 mut self: Pin<&mut Self>,
297 cx: &mut Context,
298 buf: &mut [u8],
299 ) -> Poll<io::Result<usize>> {
300 if !self.config.read_after_close && self.sender.is_closed() {
301 return Poll::Ready(Ok(0));
302 }
303
304 match self.send_window_update(cx) {
305 Poll::Ready(Ok(())) => {}
306 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
307 Poll::Pending => {}
309 }
310
311 let mut shared = self.shared();
313 let mut n = 0;
314 while let Some(chunk) = shared.buffer.front_mut() {
315 if chunk.is_empty() {
316 shared.buffer.pop();
317 continue;
318 }
319 let k = std::cmp::min(chunk.len(), buf.len() - n);
320 buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
321 n += k;
322 chunk.advance(k);
323 if n == buf.len() {
324 break;
325 }
326 }
327
328 if n > 0 {
329 log::trace!("{}/{}: read {} bytes", self.conn, self.id, n);
330 return Poll::Ready(Ok(n));
331 }
332
333 if !shared.state().can_read() {
335 log::debug!("{}/{}: eof", self.conn, self.id);
336 return Poll::Ready(Ok(0)); }
338
339 shared.reader = Some(cx.waker().clone());
342
343 Poll::Pending
344 }
345}
346
347impl AsyncWrite for Stream {
348 fn poll_write(
349 mut self: Pin<&mut Self>,
350 cx: &mut Context,
351 buf: &[u8],
352 ) -> Poll<io::Result<usize>> {
353 ready!(self
354 .sender
355 .poll_ready(cx)
356 .map_err(|_| self.write_zero_err())?);
357 let body = {
358 let mut shared = self.shared();
359 if !shared.state().can_write() {
360 log::debug!("{}/{}: can no longer write", self.conn, self.id);
361 return Poll::Ready(Err(self.write_zero_err()));
362 }
363 if shared.credit == 0 {
364 log::trace!("{}/{}: no more credit left", self.conn, self.id);
365 shared.writer = Some(cx.waker().clone());
366 return Poll::Pending;
367 }
368 let k = std::cmp::min(shared.credit as usize, buf.len());
369 let k = std::cmp::min(k, self.config.split_send_size);
370 shared.credit = shared.credit.saturating_sub(k as u32);
371 Vec::from(&buf[..k])
372 };
373 let n = body.len();
374 let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left();
375 self.add_flag(frame.header_mut());
376 log::trace!("{}/{}: write {} bytes", self.conn, self.id, n);
377
378 if frame.header().flags().contains(ACK) {
383 self.shared()
384 .update_state(self.conn, self.id, State::Open { acknowledged: true });
385 }
386
387 let cmd = StreamCommand::SendFrame(frame);
388 self.sender
389 .start_send(cmd)
390 .map_err(|_| self.write_zero_err())?;
391 Poll::Ready(Ok(n))
392 }
393
394 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
395 self.sender
396 .poll_flush_unpin(cx)
397 .map_err(|_| self.write_zero_err())
398 }
399
400 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
401 if self.is_closed() {
402 return Poll::Ready(Ok(()));
403 }
404 ready!(self
405 .sender
406 .poll_ready(cx)
407 .map_err(|_| self.write_zero_err())?);
408 let ack = if self.flag == Flag::Ack {
409 self.flag = Flag::None;
410 true
411 } else {
412 false
413 };
414 log::trace!("{}/{}: close", self.conn, self.id);
415 let cmd = StreamCommand::CloseStream { ack };
416 self.sender
417 .start_send(cmd)
418 .map_err(|_| self.write_zero_err())?;
419 self.shared()
420 .update_state(self.conn, self.id, State::SendClosed);
421 Poll::Ready(Ok(()))
422 }
423}
424
425#[derive(Debug)]
426pub(crate) struct Shared {
427 state: State,
428 pub(crate) window: u32,
429 pub(crate) credit: u32,
430 pub(crate) buffer: Chunks,
431 pub(crate) reader: Option<Waker>,
432 pub(crate) writer: Option<Waker>,
433 config: Arc<Config>,
434}
435
436impl Shared {
437 fn new(window: u32, credit: u32, config: Arc<Config>) -> Self {
438 Shared {
439 state: State::Open {
440 acknowledged: false,
441 },
442 window,
443 credit,
444 buffer: Chunks::new(),
445 reader: None,
446 writer: None,
447 config,
448 }
449 }
450
451 pub(crate) fn state(&self) -> State {
452 self.state
453 }
454
455 pub(crate) fn update_state(
457 &mut self,
458 cid: connection::Id,
459 sid: StreamId,
460 next: State,
461 ) -> State {
462 use self::State::*;
463
464 let current = self.state;
465
466 match (current, next) {
467 (Closed, _) => {}
468 (Open { .. }, _) => self.state = next,
469 (RecvClosed, Closed) => self.state = Closed,
470 (RecvClosed, Open { .. }) => {}
471 (RecvClosed, RecvClosed) => {}
472 (RecvClosed, SendClosed) => self.state = Closed,
473 (SendClosed, Closed) => self.state = Closed,
474 (SendClosed, Open { .. }) => {}
475 (SendClosed, RecvClosed) => self.state = Closed,
476 (SendClosed, SendClosed) => {}
477 }
478
479 log::trace!(
480 "{}/{}: update state: (from {:?} to {:?} -> {:?})",
481 cid,
482 sid,
483 current,
484 next,
485 self.state
486 );
487
488 current }
490
491 pub(crate) fn next_window_update(&mut self) -> Option<u32> {
499 if !self.state.can_read() {
500 return None;
501 }
502
503 let new_credit = match self.config.window_update_mode {
504 #[allow(deprecated)]
505 WindowUpdateMode::OnReceive => {
506 debug_assert!(self.config.receive_window >= self.window);
507
508 self.config.receive_window.saturating_sub(self.window)
509 }
510 WindowUpdateMode::OnRead => {
511 debug_assert!(self.config.receive_window >= self.window);
512 let bytes_received = self.config.receive_window.saturating_sub(self.window);
513 let buffer_len: u32 = self.buffer.len().try_into().unwrap_or(std::u32::MAX);
514
515 bytes_received.saturating_sub(buffer_len)
516 }
517 };
518
519 if new_credit >= self.config.receive_window / 2 {
525 Some(new_credit)
526 } else {
527 None
528 }
529 }
530
531 pub fn is_pending_ack(&self) -> bool {
533 matches!(
534 self.state(),
535 State::Open {
536 acknowledged: false
537 }
538 )
539 }
540}