tokio_rustls/
lib.rs

1//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/rustls/rustls).
2//!
3//! # Why do I need to call `poll_flush`?
4//!
5//! Most TLS implementations will have an internal buffer to improve throughput,
6//! and rustls is no exception.
7//!
//! When data is written to a `TlsStream`, it is always written into the rustls buffer first;
//! the encrypted TLS records are then taken out of rustls and written to the underlying data
//! channel (such as a `TcpStream`). If the data channel is pending, some data may remain in
//! the rustls buffer.
11//!
//! To keep things simple and correct, [TlsStream] behaves like `BufWriter`.
//! For `TlsStream<TcpStream>`, this means that data written by `poll_write` is not guaranteed
//! to have been written to the `TcpStream` yet.
//! You must call `poll_flush` to ensure that it is written to the `TcpStream`.
15//!
//! You should call `poll_flush` at the appropriate time,
//! such as when a batch of `poll_write` calls is complete and there is no more data to write.
18//!
19//! ## Why don't we write during `poll_read`?
20//!
//! We did this in the early days of `tokio-rustls`, but it caused some bugs.
//! Those bugs can be worked around, but the workarounds degrade performance
//! (spurious wakeups in the reverse direction).
//!
//! Writing from the read side would also prevent us from implementing full duplex in the future.
//!
//! See <https://github.com/tokio-rs/tls/issues/40>
27//!
28//! ## Why can't we handle it like `native-tls`?
29//!
//! When the data channel returns pending, `native-tls` misreports the number of bytes it has
//! consumed: if the data passed to `poll_write` has not actually been written to the data
//! channel, it does not return `Ready`. This avoids the need to call `poll_flush`.
//!
//! However, this does not conform to the contract of the `AsyncWrite` trait: if you pass
//! different data to two consecutive `poll_write` calls, it may cause unexpected behavior.
//!
//! See <https://github.com/tokio-rs/tls/issues/41>
38
39use std::future::Future;
40use std::io;
41#[cfg(unix)]
42use std::os::unix::io::{AsRawFd, RawFd};
43#[cfg(windows)]
44use std::os::windows::io::{AsRawSocket, RawSocket};
45use std::pin::Pin;
46use std::sync::Arc;
47use std::task::{Context, Poll};
48
49pub use rustls;
50use rustls::server::AcceptedAlert;
51use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
52use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
53
// Unwraps a `Poll::Ready(value)` into `value`, or returns `Poll::Pending` from the
// enclosing function. Paths are fully qualified so the macro expands correctly in
// any module without extra imports. It is defined before the `mod` declarations
// below so those modules can use it via textual macro scoping.
macro_rules! ready {
    ( $poll:expr ) => {
        match $poll {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
62
63pub mod client;
64mod common;
65use common::{MidHandshake, TlsState};
66pub mod server;
67
68/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
69#[derive(Clone)]
70pub struct TlsConnector {
71    inner: Arc<ClientConfig>,
72    #[cfg(feature = "early-data")]
73    early_data: bool,
74}
75
76/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
77#[derive(Clone)]
78pub struct TlsAcceptor {
79    inner: Arc<ServerConfig>,
80}
81
82impl From<Arc<ClientConfig>> for TlsConnector {
83    fn from(inner: Arc<ClientConfig>) -> TlsConnector {
84        TlsConnector {
85            inner,
86            #[cfg(feature = "early-data")]
87            early_data: false,
88        }
89    }
90}
91
92impl From<Arc<ServerConfig>> for TlsAcceptor {
93    fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
94        TlsAcceptor { inner }
95    }
96}
97
98impl TlsConnector {
99    /// Enable 0-RTT.
100    ///
101    /// If you want to use 0-RTT,
102    /// You must also set `ClientConfig.enable_early_data` to `true`.
103    #[cfg(feature = "early-data")]
104    pub fn early_data(mut self, flag: bool) -> TlsConnector {
105        self.early_data = flag;
106        self
107    }
108
109    #[inline]
110    pub fn connect<IO>(&self, domain: pki_types::ServerName<'static>, stream: IO) -> Connect<IO>
111    where
112        IO: AsyncRead + AsyncWrite + Unpin,
113    {
114        self.connect_with(domain, stream, |_| ())
115    }
116
117    pub fn connect_with<IO, F>(
118        &self,
119        domain: pki_types::ServerName<'static>,
120        stream: IO,
121        f: F,
122    ) -> Connect<IO>
123    where
124        IO: AsyncRead + AsyncWrite + Unpin,
125        F: FnOnce(&mut ClientConnection),
126    {
127        let mut session = match ClientConnection::new(self.inner.clone(), domain) {
128            Ok(session) => session,
129            Err(error) => {
130                return Connect(MidHandshake::Error {
131                    io: stream,
132                    // TODO(eliza): should this really return an `io::Error`?
133                    // Probably not...
134                    error: io::Error::new(io::ErrorKind::Other, error),
135                });
136            }
137        };
138        f(&mut session);
139
140        Connect(MidHandshake::Handshaking(client::TlsStream {
141            io: stream,
142
143            #[cfg(not(feature = "early-data"))]
144            state: TlsState::Stream,
145
146            #[cfg(feature = "early-data")]
147            state: if self.early_data && session.early_data().is_some() {
148                TlsState::EarlyData(0, Vec::new())
149            } else {
150                TlsState::Stream
151            },
152
153            #[cfg(feature = "early-data")]
154            early_waker: None,
155
156            session,
157        }))
158    }
159}
160
161impl TlsAcceptor {
162    #[inline]
163    pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
164    where
165        IO: AsyncRead + AsyncWrite + Unpin,
166    {
167        self.accept_with(stream, |_| ())
168    }
169
170    pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
171    where
172        IO: AsyncRead + AsyncWrite + Unpin,
173        F: FnOnce(&mut ServerConnection),
174    {
175        let mut session = match ServerConnection::new(self.inner.clone()) {
176            Ok(session) => session,
177            Err(error) => {
178                return Accept(MidHandshake::Error {
179                    io: stream,
180                    // TODO(eliza): should this really return an `io::Error`?
181                    // Probably not...
182                    error: io::Error::new(io::ErrorKind::Other, error),
183                });
184            }
185        };
186        f(&mut session);
187
188        Accept(MidHandshake::Handshaking(server::TlsStream {
189            session,
190            io: stream,
191            state: TlsState::Stream,
192        }))
193    }
194}
195
196pub struct LazyConfigAcceptor<IO> {
197    acceptor: rustls::server::Acceptor,
198    io: Option<IO>,
199    alert: Option<(rustls::Error, AcceptedAlert)>,
200}
201
202impl<IO> LazyConfigAcceptor<IO>
203where
204    IO: AsyncRead + AsyncWrite + Unpin,
205{
206    #[inline]
207    pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
208        Self {
209            acceptor,
210            io: Some(io),
211            alert: None,
212        }
213    }
214
215    /// Takes back the client connection. Will return `None` if called more than once or if the
216    /// connection has been accepted.
217    ///
218    /// # Example
219    ///
220    /// ```no_run
221    /// # fn choose_server_config(
222    /// #     _: rustls::server::ClientHello,
223    /// # ) -> std::sync::Arc<rustls::ServerConfig> {
224    /// #     unimplemented!();
225    /// # }
226    /// # #[allow(unused_variables)]
227    /// # async fn listen() {
228    /// use tokio::io::AsyncWriteExt;
229    /// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap();
230    /// let (stream, _) = listener.accept().await.unwrap();
231    ///
232    /// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream);
233    /// tokio::pin!(acceptor);
234    ///
235    /// match acceptor.as_mut().await {
236    ///     Ok(start) => {
237    ///         let clientHello = start.client_hello();
238    ///         let config = choose_server_config(clientHello);
239    ///         let stream = start.into_stream(config).await.unwrap();
240    ///         // Proceed with handling the ServerConnection...
241    ///     }
242    ///     Err(err) => {
243    ///         if let Some(mut stream) = acceptor.take_io() {
244    ///             stream
245    ///                 .write_all(
246    ///                     format!("HTTP/1.1 400 Invalid Input\r\n\r\n\r\n{:?}\n", err)
247    ///                         .as_bytes()
248    ///                 )
249    ///                 .await
250    ///                 .unwrap();
251    ///         }
252    ///     }
253    /// }
254    /// # }
255    /// ```
256    pub fn take_io(&mut self) -> Option<IO> {
257        self.io.take()
258    }
259}
260
261impl<IO> Future for LazyConfigAcceptor<IO>
262where
263    IO: AsyncRead + AsyncWrite + Unpin,
264{
265    type Output = Result<StartHandshake<IO>, io::Error>;
266
267    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
268        let this = self.get_mut();
269        loop {
270            let io = match this.io.as_mut() {
271                Some(io) => io,
272                None => {
273                    return Poll::Ready(Err(io::Error::new(
274                        io::ErrorKind::Other,
275                        "acceptor cannot be polled after acceptance",
276                    )))
277                }
278            };
279
280            if let Some((err, mut alert)) = this.alert.take() {
281                match alert.write(&mut common::SyncWriteAdapter { io, cx }) {
282                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
283                        this.alert = Some((err, alert));
284                        return Poll::Pending;
285                    }
286                    Ok(0) | Err(_) => {
287                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err)))
288                    }
289                    Ok(_) => {
290                        this.alert = Some((err, alert));
291                        continue;
292                    }
293                };
294            }
295
296            let mut reader = common::SyncReadAdapter { io, cx };
297            match this.acceptor.read_tls(&mut reader) {
298                Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
299                Ok(_) => {}
300                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
301                Err(e) => return Err(e).into(),
302            }
303
304            match this.acceptor.accept() {
305                Ok(Some(accepted)) => {
306                    let io = this.io.take().unwrap();
307                    return Poll::Ready(Ok(StartHandshake { accepted, io }));
308                }
309                Ok(None) => {}
310                Err((err, alert)) => {
311                    this.alert = Some((err, alert));
312                }
313            }
314        }
315    }
316}
317
318pub struct StartHandshake<IO> {
319    accepted: rustls::server::Accepted,
320    io: IO,
321}
322
323impl<IO> StartHandshake<IO>
324where
325    IO: AsyncRead + AsyncWrite + Unpin,
326{
327    pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
328        self.accepted.client_hello()
329    }
330
331    pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
332        self.into_stream_with(config, |_| ())
333    }
334
335    pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
336    where
337        F: FnOnce(&mut ServerConnection),
338    {
339        let mut conn = match self.accepted.into_connection(config) {
340            Ok(conn) => conn,
341            Err((error, alert)) => {
342                return Accept(MidHandshake::SendAlert {
343                    io: self.io,
344                    alert,
345                    // TODO(eliza): should this really return an `io::Error`?
346                    // Probably not...
347                    error: io::Error::new(io::ErrorKind::InvalidData, error),
348                });
349            }
350        };
351        f(&mut conn);
352
353        Accept(MidHandshake::Handshaking(server::TlsStream {
354            session: conn,
355            io: self.io,
356            state: TlsState::Stream,
357        }))
358    }
359}
360
361/// Future returned from `TlsConnector::connect` which will resolve
362/// once the connection handshake has finished.
363pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
364
365/// Future returned from `TlsAcceptor::accept` which will resolve
366/// once the accept handshake has finished.
367pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
368
369/// Like [Connect], but returns `IO` on failure.
370pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);
371
372/// Like [Accept], but returns `IO` on failure.
373pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
374
375impl<IO> Connect<IO> {
376    #[inline]
377    pub fn into_fallible(self) -> FallibleConnect<IO> {
378        FallibleConnect(self.0)
379    }
380
381    pub fn get_ref(&self) -> Option<&IO> {
382        match &self.0 {
383            MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
384            MidHandshake::SendAlert { io, .. } => Some(io),
385            MidHandshake::Error { io, .. } => Some(io),
386            MidHandshake::End => None,
387        }
388    }
389
390    pub fn get_mut(&mut self) -> Option<&mut IO> {
391        match &mut self.0 {
392            MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
393            MidHandshake::SendAlert { io, .. } => Some(io),
394            MidHandshake::Error { io, .. } => Some(io),
395            MidHandshake::End => None,
396        }
397    }
398}
399
400impl<IO> Accept<IO> {
401    #[inline]
402    pub fn into_fallible(self) -> FallibleAccept<IO> {
403        FallibleAccept(self.0)
404    }
405
406    pub fn get_ref(&self) -> Option<&IO> {
407        match &self.0 {
408            MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
409            MidHandshake::SendAlert { io, .. } => Some(io),
410            MidHandshake::Error { io, .. } => Some(io),
411            MidHandshake::End => None,
412        }
413    }
414
415    pub fn get_mut(&mut self) -> Option<&mut IO> {
416        match &mut self.0 {
417            MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
418            MidHandshake::SendAlert { io, .. } => Some(io),
419            MidHandshake::Error { io, .. } => Some(io),
420            MidHandshake::End => None,
421        }
422    }
423}
424
425impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
426    type Output = io::Result<client::TlsStream<IO>>;
427
428    #[inline]
429    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
430        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
431    }
432}
433
434impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
435    type Output = io::Result<server::TlsStream<IO>>;
436
437    #[inline]
438    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
439        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
440    }
441}
442
443impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
444    type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
445
446    #[inline]
447    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
448        Pin::new(&mut self.0).poll(cx)
449    }
450}
451
452impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
453    type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
454
455    #[inline]
456    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
457        Pin::new(&mut self.0).poll(cx)
458    }
459}
460
461/// Unified TLS stream type
462///
463/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use
464/// a single type to keep both client- and server-initiated TLS-encrypted connections.
465#[allow(clippy::large_enum_variant)] // https://github.com/rust-lang/rust-clippy/issues/9798
466#[derive(Debug)]
467pub enum TlsStream<T> {
468    Client(client::TlsStream<T>),
469    Server(server::TlsStream<T>),
470}
471
472impl<T> TlsStream<T> {
473    pub fn get_ref(&self) -> (&T, &CommonState) {
474        use TlsStream::*;
475        match self {
476            Client(io) => {
477                let (io, session) = io.get_ref();
478                (io, session)
479            }
480            Server(io) => {
481                let (io, session) = io.get_ref();
482                (io, session)
483            }
484        }
485    }
486
487    pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
488        use TlsStream::*;
489        match self {
490            Client(io) => {
491                let (io, session) = io.get_mut();
492                (io, &mut *session)
493            }
494            Server(io) => {
495                let (io, session) = io.get_mut();
496                (io, &mut *session)
497            }
498        }
499    }
500}
501
502impl<T> From<client::TlsStream<T>> for TlsStream<T> {
503    fn from(s: client::TlsStream<T>) -> Self {
504        Self::Client(s)
505    }
506}
507
508impl<T> From<server::TlsStream<T>> for TlsStream<T> {
509    fn from(s: server::TlsStream<T>) -> Self {
510        Self::Server(s)
511    }
512}
513
514#[cfg(unix)]
515impl<S> AsRawFd for TlsStream<S>
516where
517    S: AsRawFd,
518{
519    fn as_raw_fd(&self) -> RawFd {
520        self.get_ref().0.as_raw_fd()
521    }
522}
523
524#[cfg(windows)]
525impl<S> AsRawSocket for TlsStream<S>
526where
527    S: AsRawSocket,
528{
529    fn as_raw_socket(&self) -> RawSocket {
530        self.get_ref().0.as_raw_socket()
531    }
532}
533
534impl<T> AsyncRead for TlsStream<T>
535where
536    T: AsyncRead + AsyncWrite + Unpin,
537{
538    #[inline]
539    fn poll_read(
540        self: Pin<&mut Self>,
541        cx: &mut Context<'_>,
542        buf: &mut ReadBuf<'_>,
543    ) -> Poll<io::Result<()>> {
544        match self.get_mut() {
545            TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
546            TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
547        }
548    }
549}
550
551impl<T> AsyncWrite for TlsStream<T>
552where
553    T: AsyncRead + AsyncWrite + Unpin,
554{
555    #[inline]
556    fn poll_write(
557        self: Pin<&mut Self>,
558        cx: &mut Context<'_>,
559        buf: &[u8],
560    ) -> Poll<io::Result<usize>> {
561        match self.get_mut() {
562            TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
563            TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
564        }
565    }
566
567    #[inline]
568    fn poll_write_vectored(
569        self: Pin<&mut Self>,
570        cx: &mut Context<'_>,
571        bufs: &[io::IoSlice<'_>],
572    ) -> Poll<io::Result<usize>> {
573        match self.get_mut() {
574            TlsStream::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs),
575            TlsStream::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs),
576        }
577    }
578
579    #[inline]
580    fn is_write_vectored(&self) -> bool {
581        match self {
582            TlsStream::Client(x) => x.is_write_vectored(),
583            TlsStream::Server(x) => x.is_write_vectored(),
584        }
585    }
586
587    #[inline]
588    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
589        match self.get_mut() {
590            TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
591            TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
592        }
593    }
594
595    #[inline]
596    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
597        match self.get_mut() {
598            TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx),
599            TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx),
600        }
601    }
602}