1use std::io::{self, IoSlice, Read, Write};
2use std::ops::{Deref, DerefMut};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use rustls::{ConnectionCommon, SideData};
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8
9mod handshake;
10pub(crate) use handshake::{IoSession, MidHandshake};
11
/// Lifecycle state of a TLS stream: optional early-data phase, normal
/// operation, and independent shutdown of the read and write halves.
#[derive(Debug)]
pub enum TlsState {
    /// Early-data phase (only with the `early-data` feature). The payload is
    /// opaque here — presumably a write position and the buffered early bytes;
    /// TODO confirm against the early-data code.
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    /// Normal operation; both halves are open.
    Stream,
    /// Only the read half has been shut down.
    ReadShutdown,
    /// Only the write half has been shut down.
    WriteShutdown,
    /// Both halves have been shut down.
    FullyShutdown,
}

impl TlsState {
    /// Marks the read half as closed. If the write half is already closed the
    /// state becomes `FullyShutdown`; otherwise it becomes `ReadShutdown`.
    #[inline]
    pub fn shutdown_read(&mut self) {
        let next = if self.writeable() {
            TlsState::ReadShutdown
        } else {
            TlsState::FullyShutdown
        };
        *self = next;
    }

    /// Marks the write half as closed. If the read half is already closed the
    /// state becomes `FullyShutdown`; otherwise it becomes `WriteShutdown`.
    #[inline]
    pub fn shutdown_write(&mut self) {
        let next = if self.readable() {
            TlsState::WriteShutdown
        } else {
            TlsState::FullyShutdown
        };
        *self = next;
    }

    /// Returns `true` unless the write half has been shut down.
    #[inline]
    pub fn writeable(&self) -> bool {
        match self {
            TlsState::WriteShutdown | TlsState::FullyShutdown => false,
            _ => true,
        }
    }

    /// Returns `true` unless the read half has been shut down.
    #[inline]
    pub fn readable(&self) -> bool {
        match self {
            TlsState::ReadShutdown | TlsState::FullyShutdown => false,
            _ => true,
        }
    }

    /// Returns `true` while in the early-data phase.
    #[inline]
    #[cfg(feature = "early-data")]
    pub fn is_early_data(&self) -> bool {
        matches!(self, TlsState::EarlyData(..))
    }

    /// Without the `early-data` feature there is no early-data phase.
    #[inline]
    #[cfg(not(feature = "early-data"))]
    pub const fn is_early_data(&self) -> bool {
        false
    }
}
61
62pub struct Stream<'a, IO, C> {
63 pub io: &'a mut IO,
64 pub session: &'a mut C,
65 pub eof: bool,
66}
67
68impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
69where
70 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
71 SD: SideData,
72{
73 pub fn new(io: &'a mut IO, session: &'a mut C) -> Self {
74 Stream {
75 io,
76 session,
77 eof: false,
80 }
81 }
82
83 pub fn set_eof(mut self, eof: bool) -> Self {
84 self.eof = eof;
85 self
86 }
87
88 pub fn as_mut_pin(&mut self) -> Pin<&mut Self> {
89 Pin::new(self)
90 }
91
92 pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
93 let mut reader = SyncReadAdapter { io: self.io, cx };
94
95 let n = match self.session.read_tls(&mut reader) {
96 Ok(n) => n,
97 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
98 Err(err) => return Poll::Ready(Err(err)),
99 };
100
101 let stats = self.session.process_new_packets().map_err(|err| {
102 let _ = self.write_io(cx);
106
107 io::Error::new(io::ErrorKind::InvalidData, err)
108 })?;
109
110 if stats.peer_has_closed() && self.session.is_handshaking() {
111 return Poll::Ready(Err(io::Error::new(
112 io::ErrorKind::UnexpectedEof,
113 "tls handshake alert",
114 )));
115 }
116
117 Poll::Ready(Ok(n))
118 }
119
120 pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
121 let mut writer = SyncWriteAdapter { io: self.io, cx };
122
123 match self.session.write_tls(&mut writer) {
124 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
125 result => Poll::Ready(result),
126 }
127 }
128
129 pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
130 let mut wrlen = 0;
131 let mut rdlen = 0;
132
133 loop {
134 let mut write_would_block = false;
135 let mut read_would_block = false;
136 let mut need_flush = false;
137
138 while self.session.wants_write() {
139 match self.write_io(cx) {
140 Poll::Ready(Ok(n)) => {
141 wrlen += n;
142 need_flush = true;
143 }
144 Poll::Pending => {
145 write_would_block = true;
146 break;
147 }
148 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
149 }
150 }
151
152 if need_flush {
153 match Pin::new(&mut self.io).poll_flush(cx) {
154 Poll::Ready(Ok(())) => (),
155 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
156 Poll::Pending => write_would_block = true,
157 }
158 }
159
160 while !self.eof && self.session.wants_read() {
161 match self.read_io(cx) {
162 Poll::Ready(Ok(0)) => self.eof = true,
163 Poll::Ready(Ok(n)) => rdlen += n,
164 Poll::Pending => {
165 read_would_block = true;
166 break;
167 }
168 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
169 }
170 }
171
172 return match (self.eof, self.session.is_handshaking()) {
173 (true, true) => {
174 let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
175 Poll::Ready(Err(err))
176 }
177 (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
178 (_, true) if write_would_block || read_would_block => {
179 if rdlen != 0 || wrlen != 0 {
180 Poll::Ready(Ok((rdlen, wrlen)))
181 } else {
182 Poll::Pending
183 }
184 }
185 (..) => continue,
186 };
187 }
188 }
189}
190
191impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C>
192where
193 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
194 SD: SideData,
195{
196 fn poll_read(
197 mut self: Pin<&mut Self>,
198 cx: &mut Context<'_>,
199 buf: &mut ReadBuf<'_>,
200 ) -> Poll<io::Result<()>> {
201 let mut io_pending = false;
202
203 while !self.eof && self.session.wants_read() {
205 match self.read_io(cx) {
206 Poll::Ready(Ok(0)) => {
207 break;
208 }
209 Poll::Ready(Ok(_)) => (),
210 Poll::Pending => {
211 io_pending = true;
212 break;
213 }
214 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
215 }
216 }
217
218 match self.session.reader().read(buf.initialize_unfilled()) {
219 Ok(n) => {
228 buf.advance(n);
229 Poll::Ready(Ok(()))
230 }
231
232 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
234 if !io_pending {
235 cx.waker().wake_by_ref();
241 }
242
243 Poll::Pending
244 }
245
246 Err(err) => Poll::Ready(Err(err)),
247 }
248 }
249}
250
251impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'a, IO, C>
252where
253 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
254 SD: SideData,
255{
256 fn poll_write(
257 mut self: Pin<&mut Self>,
258 cx: &mut Context,
259 buf: &[u8],
260 ) -> Poll<io::Result<usize>> {
261 let mut pos = 0;
262
263 while pos != buf.len() {
264 let mut would_block = false;
265
266 match self.session.writer().write(&buf[pos..]) {
267 Ok(n) => pos += n,
268 Err(err) => return Poll::Ready(Err(err)),
269 };
270
271 while self.session.wants_write() {
272 match self.write_io(cx) {
273 Poll::Ready(Ok(0)) | Poll::Pending => {
274 would_block = true;
275 break;
276 }
277 Poll::Ready(Ok(_)) => (),
278 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
279 }
280 }
281
282 return match (pos, would_block) {
283 (0, true) => Poll::Pending,
284 (n, true) => Poll::Ready(Ok(n)),
285 (_, false) => continue,
286 };
287 }
288
289 Poll::Ready(Ok(pos))
290 }
291
292 fn poll_write_vectored(
293 mut self: Pin<&mut Self>,
294 cx: &mut Context<'_>,
295 bufs: &[IoSlice<'_>],
296 ) -> Poll<io::Result<usize>> {
297 if bufs.iter().all(|buf| buf.is_empty()) {
298 return Poll::Ready(Ok(0));
299 }
300
301 loop {
302 let mut would_block = false;
303 let written = self.session.writer().write_vectored(bufs)?;
304
305 while self.session.wants_write() {
306 match self.write_io(cx) {
307 Poll::Ready(Ok(0)) | Poll::Pending => {
308 would_block = true;
309 break;
310 }
311 Poll::Ready(Ok(_)) => (),
312 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
313 }
314 }
315
316 return match (written, would_block) {
317 (0, true) => Poll::Pending,
318 (0, false) => continue,
319 (n, _) => Poll::Ready(Ok(n)),
320 };
321 }
322 }
323
324 #[inline]
325 fn is_write_vectored(&self) -> bool {
326 true
327 }
328
329 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
330 self.session.writer().flush()?;
331 while self.session.wants_write() {
332 ready!(self.write_io(cx))?;
333 }
334 Pin::new(&mut self.io).poll_flush(cx)
335 }
336
337 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
338 while self.session.wants_write() {
339 ready!(self.write_io(cx))?;
340 }
341
342 Poll::Ready(match ready!(Pin::new(&mut self.io).poll_shutdown(cx)) {
343 Ok(()) => Ok(()),
344 Err(err) if err.kind() == io::ErrorKind::NotConnected => Ok(()),
346 Err(err) => Err(err),
347 })
348 }
349}
350
351pub struct SyncReadAdapter<'a, 'b, T> {
356 pub io: &'a mut T,
357 pub cx: &'a mut Context<'b>,
358}
359
360impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> {
361 #[inline]
362 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
363 let mut buf = ReadBuf::new(buf);
364 match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) {
365 Poll::Ready(Ok(())) => Ok(buf.filled().len()),
366 Poll::Ready(Err(err)) => Err(err),
367 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
368 }
369 }
370}
371
372pub struct SyncWriteAdapter<'a, 'b, T> {
377 pub io: &'a mut T,
378 pub cx: &'a mut Context<'b>,
379}
380
381impl<'a, 'b, T: Unpin> SyncWriteAdapter<'a, 'b, T> {
382 #[inline]
383 fn poll_with<U>(
384 &mut self,
385 f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
386 ) -> io::Result<U> {
387 match f(Pin::new(self.io), self.cx) {
388 Poll::Ready(result) => result,
389 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
390 }
391 }
392}
393
394impl<'a, 'b, T: AsyncWrite + Unpin> Write for SyncWriteAdapter<'a, 'b, T> {
395 #[inline]
396 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
397 self.poll_with(|io, cx| io.poll_write(cx, buf))
398 }
399
400 #[inline]
401 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
402 self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
403 }
404
405 fn flush(&mut self) -> io::Result<()> {
406 self.poll_with(|io, cx| io.poll_flush(cx))
407 }
408}
409
410#[cfg(test)]
411mod test_stream;