tokio_rustls/common/mod.rs

use std::io::{self, IoSlice, Read, Write};
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};

use rustls::{ConnectionCommon, SideData};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

mod handshake;
pub(crate) use handshake::{IoSession, MidHandshake};
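
/// Tracks which halves of the TLS stream have been shut down.
///
/// With the `early-data` feature enabled, the `EarlyData` variant carries a
/// write cursor and a buffer of plaintext written as TLS 1.3 early data.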
#[derive(Debug)]
pub enum TlsState {
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    Stream,
    ReadShutdown,
    WriteShutdown,
    FullyShutdown,
}

impl TlsState {
    #[inline]
    pub fn shutdown_read(&mut self) {
        match *self {
            TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
            _ => *self = TlsState::ReadShutdown,
        }
    }

    #[inline]
    pub fn shutdown_write(&mut self) {
        match *self {
            TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
            _ => *self = TlsState::WriteShutdown,
        }
    }

    #[inline]
    pub fn writeable(&self) -> bool {
        !matches!(*self, TlsState::WriteShutdown | TlsState::FullyShutdown)
    }

    #[inline]
    pub fn readable(&self) -> bool {
        !matches!(*self, TlsState::ReadShutdown | TlsState::FullyShutdown)
    }

    #[inline]
    #[cfg(feature = "early-data")]
    pub fn is_early_data(&self) -> bool {
        matches!(self, TlsState::EarlyData(..))
    }

    #[inline]
    #[cfg(not(feature = "early-data"))]
    pub const fn is_early_data(&self) -> bool {
        false
    }
}
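
/// Couples a raw async I/O object with a rustls session.
///
/// The `AsyncRead`/`AsyncWrite` implementations below shuttle TLS records
/// between `io` and `session`, exposing the decrypted plaintext to callers.
/// `eof` records that the underlying transport has reached end-of-file.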
pub struct Stream<'a, IO, C> {
    pub io: &'a mut IO,
    pub session: &'a mut C,
    pub eof: bool,
}

impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
where
    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
    SD: SideData,
{
    pub fn new(io: &'a mut IO, session: &'a mut C) -> Self {
        Stream {
            io,
            session,
            // This state is currently only used to detect EOF, so starting in
            // either the Stream or the EarlyData state works equally well.
            eof: false,
        }
    }

    pub fn set_eof(mut self, eof: bool) -> Self {
        self.eof = eof;
        self
    }

    pub fn as_mut_pin(&mut self) -> Pin<&mut Self> {
        Pin::new(self)
    }
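
    /// Reads encrypted TLS records from `io` into the rustls session and
    /// processes them. Returns the number of TLS bytes read, `Pending` if the
    /// underlying I/O is not ready, or an error if record processing fails.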
    pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
        let mut reader = SyncReadAdapter { io: self.io, cx };

        let n = match self.session.read_tls(&mut reader) {
            Ok(n) => n,
            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
            Err(err) => return Poll::Ready(Err(err)),
        };

        let stats = self.session.process_new_packets().map_err(|err| {
            // In case we have an alert to send describing this error, try a
            // last-gasp write -- but don't let that write's outcome mask the
            // primary error.
            let _ = self.write_io(cx);

            io::Error::new(io::ErrorKind::InvalidData, err)
        })?;

        if stats.peer_has_closed() && self.session.is_handshaking() {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "tls handshake alert",
            )));
        }

        Poll::Ready(Ok(n))
    }
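
    /// Writes any TLS records buffered in the rustls session out to `io`.
    /// Returns the number of TLS bytes written, or `Pending` if the
    /// underlying I/O is not ready.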
    pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
        let mut writer = SyncWriteAdapter { io: self.io, cx };

        match self.session.write_tls(&mut writer) {
            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            result => Poll::Ready(result),
        }
    }
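
    /// Drives pending TLS traffic (typically the handshake) by writing and
    /// reading records until rustls no longer wants I/O or the underlying
    /// transport would block. Returns the `(read, written)` TLS byte counts,
    /// reporting partial progress instead of `Pending` when any bytes moved.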
    pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
        let mut wrlen = 0;
        let mut rdlen = 0;

        loop {
            let mut write_would_block = false;
            let mut read_would_block = false;
            let mut need_flush = false;

            while self.session.wants_write() {
                match self.write_io(cx) {
                    Poll::Ready(Ok(n)) => {
                        wrlen += n;
                        need_flush = true;
                    }
                    Poll::Pending => {
                        write_would_block = true;
                        break;
                    }
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            if need_flush {
                match Pin::new(&mut self.io).poll_flush(cx) {
                    Poll::Ready(Ok(())) => (),
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                    Poll::Pending => write_would_block = true,
                }
            }

            while !self.eof && self.session.wants_read() {
                match self.read_io(cx) {
                    Poll::Ready(Ok(0)) => self.eof = true,
                    Poll::Ready(Ok(n)) => rdlen += n,
                    Poll::Pending => {
                        read_would_block = true;
                        break;
                    }
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            return match (self.eof, self.session.is_handshaking()) {
                (true, true) => {
                    let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
                    Poll::Ready(Err(err))
                }
                (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
                (_, true) if write_would_block || read_would_block => {
                    if rdlen != 0 || wrlen != 0 {
                        Poll::Ready(Ok((rdlen, wrlen)))
                    } else {
                        Poll::Pending
                    }
                }
                (..) => continue,
            };
        }
    }
}
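
// The `AsyncRead` implementation first pumps ciphertext from `io` into the
// session (for as long as rustls wants more input), then copies whatever
// plaintext rustls has buffered into the caller's `ReadBuf`.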
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C>
where
    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
    SD: SideData,
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let mut io_pending = false;

        // read a packet
        while !self.eof && self.session.wants_read() {
            match self.read_io(cx) {
                Poll::Ready(Ok(0)) => {
                    break;
                }
                Poll::Ready(Ok(_)) => (),
                Poll::Pending => {
                    io_pending = true;
                    break;
                }
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            }
        }

        match self.session.reader().read(buf.initialize_unfilled()) {
            // If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the
            // connection with a `CloseNotify` message and no more data will be forthcoming.
            // Otherwise Rustls yielded plaintext: advance the buffer and let the caller
            // poll again if it wants more.
            //
            // We don't need to modify `self.eof` here, because it is only a temporary mark;
            // Rustls only returns 0 once it has received `CloseNotify`,
            // in which case no additional processing is required.
            Ok(n) => {
                buf.advance(n);
                Poll::Ready(Ok(()))
            }

            // Rustls doesn't have more data to yield, but it believes the connection is open.
            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
                if !io_pending {
                    // If `wants_read()` is satisfied, rustls will not return `WouldBlock`,
                    // but if it does, we can simply try again.
                    //
                    // If the rustls state is abnormal, this may cause a cyclic wakeup,
                    // but tokio's cooperative budget will prevent an infinite wakeup loop.
                    cx.waker().wake_by_ref();
                }

                Poll::Pending
            }

            Err(err) => Poll::Ready(Err(err)),
        }
    }
}
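
// The `AsyncWrite` implementation buffers plaintext in the rustls session and
// then drains the resulting TLS records to `io`. If the transport would block
// after some plaintext was accepted, the partial write count is reported
// instead of `Pending` so the caller does not resubmit those bytes.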
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'a, IO, C>
where
    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
    SD: SideData,
{
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let mut pos = 0;

        while pos != buf.len() {
            let mut would_block = false;

            match self.session.writer().write(&buf[pos..]) {
                Ok(n) => pos += n,
                Err(err) => return Poll::Ready(Err(err)),
            };

            while self.session.wants_write() {
                match self.write_io(cx) {
                    Poll::Ready(Ok(0)) | Poll::Pending => {
                        would_block = true;
                        break;
                    }
                    Poll::Ready(Ok(_)) => (),
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            return match (pos, would_block) {
                (0, true) => Poll::Pending,
                (n, true) => Poll::Ready(Ok(n)),
                (_, false) => continue,
            };
        }

        Poll::Ready(Ok(pos))
    }

    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        if bufs.iter().all(|buf| buf.is_empty()) {
            return Poll::Ready(Ok(0));
        }

        loop {
            let mut would_block = false;
            let written = self.session.writer().write_vectored(bufs)?;

            while self.session.wants_write() {
                match self.write_io(cx) {
                    Poll::Ready(Ok(0)) | Poll::Pending => {
                        would_block = true;
                        break;
                    }
                    Poll::Ready(Ok(_)) => (),
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            return match (written, would_block) {
                (0, true) => Poll::Pending,
                (0, false) => continue,
                (n, _) => Poll::Ready(Ok(n)),
            };
        }
    }

    #[inline]
    fn is_write_vectored(&self) -> bool {
        true
    }
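
    // `poll_flush` flushes the rustls writer, drains any buffered TLS records
    // to `io`, and then flushes `io` itself. `poll_shutdown` below likewise
    // drains pending records before shutting down the underlying transport.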
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        self.session.writer().flush()?;
        while self.session.wants_write() {
            ready!(self.write_io(cx))?;
        }
        Pin::new(&mut self.io).poll_flush(cx)
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        while self.session.wants_write() {
            ready!(self.write_io(cx))?;
        }

        Poll::Ready(match ready!(Pin::new(&mut self.io).poll_shutdown(cx)) {
            Ok(()) => Ok(()),
            // When trying to shut down, not being connected seems fine
            Err(err) if err.kind() == io::ErrorKind::NotConnected => Ok(()),
            Err(err) => Err(err),
        })
    }
}
/// An adapter that implements a [`Read`] interface for [`AsyncRead`] types and an
/// associated [`Context`].
///
/// Turns `Poll::Pending` into `WouldBlock`.
pub struct SyncReadAdapter<'a, 'b, T> {
    pub io: &'a mut T,
    pub cx: &'a mut Context<'b>,
}

impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> {
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let mut buf = ReadBuf::new(buf);
        match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) {
            Poll::Ready(Ok(())) => Ok(buf.filled().len()),
            Poll::Ready(Err(err)) => Err(err),
            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
        }
    }
}
/// An adapter that implements a [`Write`] interface for [`AsyncWrite`] types and an
/// associated [`Context`].
///
/// Turns `Poll::Pending` into `WouldBlock`.
pub struct SyncWriteAdapter<'a, 'b, T> {
    pub io: &'a mut T,
    pub cx: &'a mut Context<'b>,
}

impl<'a, 'b, T: Unpin> SyncWriteAdapter<'a, 'b, T> {
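    /// Runs `f` against the pinned inner I/O object, mapping `Poll::Pending`
    /// to a `WouldBlock` error so it can be surfaced through the blocking
    /// `Write` interface.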
    #[inline]
    fn poll_with<U>(
        &mut self,
        f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
    ) -> io::Result<U> {
        match f(Pin::new(self.io), self.cx) {
            Poll::Ready(result) => result,
            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
        }
    }
}

impl<'a, 'b, T: AsyncWrite + Unpin> Write for SyncWriteAdapter<'a, 'b, T> {
    #[inline]
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.poll_with(|io, cx| io.poll_write(cx, buf))
    }

    #[inline]
    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
        self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
    }

    fn flush(&mut self) -> io::Result<()> {
        self.poll_with(|io, cx| io.poll_flush(cx))
    }
}

#[cfg(test)]
mod test_stream;