futures_rustls/common/
mod.rs

1mod handshake;
2
3use futures_io::{AsyncRead, AsyncWrite};
4pub(crate) use handshake::{IoSession, MidHandshake};
5use rustls::{ConnectionCommon, SideData};
6use std::io::{self, IoSlice, Read, Write};
7use std::ops::{Deref, DerefMut};
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
/// Tracks which directions of a TLS stream are still open.
#[derive(Debug)]
pub enum TlsState {
    /// Early-data phase. Carries a `usize` and a byte buffer
    /// (presumably a position and the buffered early data; confirm at the use site).
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    /// Both directions are open.
    Stream,
    /// Reading has been shut down; writing is still allowed.
    ReadShutdown,
    /// Writing has been shut down; reading is still allowed.
    WriteShutdown,
    /// Both directions have been shut down.
    FullyShutdown,
}

impl TlsState {
    /// Records that the read direction is closed.
    ///
    /// If the write direction was already closed the state becomes
    /// `FullyShutdown`; otherwise it becomes `ReadShutdown`.
    #[inline]
    pub fn shutdown_read(&mut self) {
        *self = if self.writeable() {
            TlsState::ReadShutdown
        } else {
            TlsState::FullyShutdown
        };
    }

    /// Records that the write direction is closed.
    ///
    /// If the read direction was already closed the state becomes
    /// `FullyShutdown`; otherwise it becomes `WriteShutdown`.
    #[inline]
    pub fn shutdown_write(&mut self) {
        *self = if self.readable() {
            TlsState::WriteShutdown
        } else {
            TlsState::FullyShutdown
        };
    }

    /// Returns `true` unless the write direction has been shut down.
    #[inline]
    pub fn writeable(&self) -> bool {
        match self {
            TlsState::WriteShutdown | TlsState::FullyShutdown => false,
            #[cfg(feature = "early-data")]
            TlsState::EarlyData(..) => true,
            TlsState::Stream | TlsState::ReadShutdown => true,
        }
    }

    /// Returns `true` unless the read direction has been shut down.
    #[inline]
    pub fn readable(&self) -> bool {
        match self {
            TlsState::ReadShutdown | TlsState::FullyShutdown => false,
            #[cfg(feature = "early-data")]
            TlsState::EarlyData(..) => true,
            TlsState::Stream | TlsState::WriteShutdown => true,
        }
    }

    /// Returns `true` while the state is `EarlyData`.
    #[inline]
    #[cfg(feature = "early-data")]
    pub fn is_early_data(&self) -> bool {
        matches!(self, TlsState::EarlyData(..))
    }

    /// Always `false`: the `early-data` feature is disabled in this build.
    #[inline]
    #[cfg(not(feature = "early-data"))]
    pub const fn is_early_data(&self) -> bool {
        false
    }
}
60
/// A borrowed view that pairs a transport with a TLS session.
///
/// Both parts are held by mutable reference, so a `Stream` is meant to be
/// built on demand and dropped after use.
pub struct Stream<'a, IO, C> {
    /// The underlying transport that carries the TLS bytes.
    pub io: &'a mut IO,
    /// The TLS session whose records travel over `io`.
    pub session: &'a mut C,
    /// Whether the transport is considered to have reached end-of-file.
    pub eof: bool,
}
66
67impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
68where
69    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
70    SD: SideData,
71{
72    pub fn new(io: &'a mut IO, session: &'a mut C) -> Self {
73        Stream {
74            io,
75            session,
76            // The state so far is only used to detect EOF, so either Stream
77            // or EarlyData state should both be all right.
78            eof: false,
79        }
80    }
81
82    pub fn set_eof(mut self, eof: bool) -> Self {
83        self.eof = eof;
84        self
85    }
86
87    pub fn as_mut_pin(&mut self) -> Pin<&mut Self> {
88        Pin::new(self)
89    }
90
91    pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
92        struct Reader<'a, 'b, T> {
93            io: &'a mut T,
94            cx: &'a mut Context<'b>,
95        }
96
97        impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> {
98            #[inline]
99            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
100                match Pin::new(&mut self.io).poll_read(self.cx, buf) {
101                    Poll::Ready(Ok(n)) => Ok(n),
102                    Poll::Ready(Err(err)) => Err(err),
103                    Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
104                }
105            }
106        }
107
108        let mut reader = Reader { io: self.io, cx };
109
110        let n = match self.session.read_tls(&mut reader) {
111            Ok(n) => n,
112            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
113            Err(err) => return Poll::Ready(Err(err)),
114        };
115
116        let stats = self.session.process_new_packets().map_err(|err| {
117            // In case we have an alert to send describing this error,
118            // try a last-gasp write -- but don't predate the primary
119            // error.
120            let _ = self.write_io(cx);
121
122            io::Error::new(io::ErrorKind::InvalidData, err)
123        })?;
124
125        if stats.peer_has_closed() && self.session.is_handshaking() {
126            return Poll::Ready(Err(io::Error::new(
127                io::ErrorKind::UnexpectedEof,
128                "tls handshake alert",
129            )));
130        }
131
132        Poll::Ready(Ok(n))
133    }
134
135    pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
136        struct Writer<'a, 'b, T> {
137            io: &'a mut T,
138            cx: &'a mut Context<'b>,
139        }
140
141        impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> {
142            #[inline]
143            fn poll_with<U>(
144                &mut self,
145                f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
146            ) -> io::Result<U> {
147                match f(Pin::new(&mut self.io), self.cx) {
148                    Poll::Ready(result) => result,
149                    Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
150                }
151            }
152        }
153
154        impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
155            #[inline]
156            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
157                self.poll_with(|io, cx| io.poll_write(cx, buf))
158            }
159
160            #[inline]
161            fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
162                self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
163            }
164
165            fn flush(&mut self) -> io::Result<()> {
166                self.poll_with(|io, cx| io.poll_flush(cx))
167            }
168        }
169
170        let mut writer = Writer { io: self.io, cx };
171
172        match self.session.write_tls(&mut writer) {
173            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
174            result => Poll::Ready(result),
175        }
176    }
177
178    pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
179        let mut wrlen = 0;
180        let mut rdlen = 0;
181
182        loop {
183            let mut write_would_block = false;
184            let mut read_would_block = false;
185            let mut need_flush = false;
186
187            while self.session.wants_write() {
188                match self.write_io(cx) {
189                    Poll::Ready(Ok(n)) => {
190                        wrlen += n;
191                        need_flush = true;
192                    }
193                    Poll::Pending => {
194                        write_would_block = true;
195                        break;
196                    }
197                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
198                }
199            }
200
201            if need_flush {
202                match Pin::new(&mut self.io).poll_flush(cx) {
203                    Poll::Ready(Ok(())) => (),
204                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
205                    Poll::Pending => write_would_block = true,
206                }
207            }
208
209            while !self.eof && self.session.wants_read() {
210                match self.read_io(cx) {
211                    Poll::Ready(Ok(0)) => self.eof = true,
212                    Poll::Ready(Ok(n)) => rdlen += n,
213                    Poll::Pending => {
214                        read_would_block = true;
215                        break;
216                    }
217                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
218                }
219            }
220
221            return match (self.eof, self.session.is_handshaking()) {
222                (true, true) => {
223                    let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
224                    Poll::Ready(Err(err))
225                }
226                (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
227                (_, true) if write_would_block || read_would_block => {
228                    if rdlen != 0 || wrlen != 0 {
229                        Poll::Ready(Ok((rdlen, wrlen)))
230                    } else {
231                        Poll::Pending
232                    }
233                }
234                (..) => continue,
235            };
236        }
237    }
238}
239
240impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C>
241where
242    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
243    SD: SideData,
244{
245    fn poll_read(
246        mut self: Pin<&mut Self>,
247        cx: &mut Context<'_>,
248        buf: &mut [u8],
249    ) -> Poll<io::Result<usize>> {
250        let mut io_pending = false;
251
252        // read a packet
253        while !self.eof && self.session.wants_read() {
254            match self.read_io(cx) {
255                Poll::Ready(Ok(0)) => {
256                    break;
257                }
258                Poll::Ready(Ok(_)) => (),
259                Poll::Pending => {
260                    io_pending = true;
261                    break;
262                }
263                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
264            }
265        }
266
267        match self.session.reader().read(buf) {
268            // If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the
269            // connection with a `CloseNotify` message and no more data will be forthcoming.
270            //
271            // Rustls yielded more data: advance the buffer, then see if more data is coming.
272            //
273            // We don't need to modify `self.eof` here, because it is only a temporary mark.
274            // rustls will only return 0 if is has received `CloseNotify`,
275            // in which case no additional processing is required.
276            Ok(n) => Poll::Ready(Ok(n)),
277
278            // Rustls doesn't have more data to yield, but it believes the connection is open.
279            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
280                if !io_pending {
281                    // If `wants_read()` is satisfied, rustls will not return `WouldBlock`.
282                    // but if it does, we can try again.
283                    //
284                    // If the rustls state is abnormal, it may cause a cyclic wakeup.
285                    // but tokio's cooperative budget will prevent infinite wakeup.
286                    cx.waker().wake_by_ref();
287                }
288
289                Poll::Pending
290            }
291
292            Err(err) => Poll::Ready(Err(err)),
293        }
294    }
295}
296
297impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'a, IO, C>
298where
299    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
300    SD: SideData,
301{
302    fn poll_write(
303        mut self: Pin<&mut Self>,
304        cx: &mut Context,
305        buf: &[u8],
306    ) -> Poll<io::Result<usize>> {
307        let mut pos = 0;
308
309        while pos != buf.len() {
310            let mut would_block = false;
311
312            match self.session.writer().write(&buf[pos..]) {
313                Ok(n) => pos += n,
314                Err(err) => return Poll::Ready(Err(err)),
315            };
316
317            while self.session.wants_write() {
318                match self.write_io(cx) {
319                    Poll::Ready(Ok(0)) | Poll::Pending => {
320                        would_block = true;
321                        break;
322                    }
323                    Poll::Ready(Ok(_)) => (),
324                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
325                }
326            }
327
328            return match (pos, would_block) {
329                (0, true) => Poll::Pending,
330                (n, true) => Poll::Ready(Ok(n)),
331                (_, false) => continue,
332            };
333        }
334
335        Poll::Ready(Ok(pos))
336    }
337
338    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
339        self.session.writer().flush()?;
340        while self.session.wants_write() {
341            ready!(self.write_io(cx))?;
342        }
343        Pin::new(&mut self.io).poll_flush(cx)
344    }
345
346    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
347        while self.session.wants_write() {
348            ready!(self.write_io(cx))?;
349        }
350        Pin::new(&mut self.io).poll_close(cx)
351    }
352}
353
/// Bridges an [`AsyncRead`] transport to the blocking-style [`Read`] trait.
///
/// The adapter borrows the transport together with the task [`Context`] of
/// the current poll. A read that would return `Poll::Pending` is reported as
/// an `io::ErrorKind::WouldBlock` error instead.
pub struct SyncReadAdapter<'a, 'b, T> {
    /// The asynchronous transport to read from.
    pub io: &'a mut T,
    /// The task context passed to the transport on each read.
    pub cx: &'a mut Context<'b>,
}
362
363impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> {
364    #[inline]
365    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
366        match Pin::new(&mut self.io).poll_read(self.cx, buf) {
367            Poll::Ready(Ok(n)) => Ok(n),
368            Poll::Ready(Err(err)) => Err(err),
369            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
370        }
371    }
372}
373
374#[cfg(test)]
375mod test_stream;