futures_rustls/
lib.rs

1//! Asynchronous TLS/SSL streams for futures using [Rustls](https://github.com/ctz/rustls).
2
// Evaluates an expression of type `Poll<T>`.
// `Poll::Ready(value)` yields `value`.
// `Poll::Pending` makes the *enclosing function* return `Poll::Pending`.
macro_rules! ready {
    ( $e:expr ) => {
        match $e {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
11
12pub mod client;
13mod common;
14pub mod server;
15
16use common::{MidHandshake, Stream, TlsState};
17use futures_io::{AsyncRead, AsyncWrite};
18use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
19use std::future::Future;
20use std::io;
21#[cfg(unix)]
22use std::os::unix::io::{AsRawFd, RawFd};
23#[cfg(windows)]
24use std::os::windows::io::{AsRawSocket, RawSocket};
25use std::pin::Pin;
26use std::sync::Arc;
27use std::task::{Context, Poll};
28
29pub use rustls;
30
31/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
32#[derive(Clone)]
33pub struct TlsConnector {
34    inner: Arc<ClientConfig>,
35    #[cfg(feature = "early-data")]
36    early_data: bool,
37}
38
39/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
40#[derive(Clone)]
41pub struct TlsAcceptor {
42    inner: Arc<ServerConfig>,
43}
44
45impl From<Arc<ClientConfig>> for TlsConnector {
46    fn from(inner: Arc<ClientConfig>) -> TlsConnector {
47        TlsConnector {
48            inner,
49            #[cfg(feature = "early-data")]
50            early_data: false,
51        }
52    }
53}
54
55impl From<Arc<ServerConfig>> for TlsAcceptor {
56    fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
57        TlsAcceptor { inner }
58    }
59}
60
61impl TlsConnector {
62    /// Enable 0-RTT.
63    ///
64    /// If you want to use 0-RTT,
65    /// You must also set `ClientConfig.enable_early_data` to `true`.
66    #[cfg(feature = "early-data")]
67    pub fn early_data(mut self, flag: bool) -> TlsConnector {
68        self.early_data = flag;
69        self
70    }
71
72    #[inline]
73    pub fn connect<IO>(&self, domain: rustls::ServerName, stream: IO) -> Connect<IO>
74    where
75        IO: AsyncRead + AsyncWrite + Unpin,
76    {
77        self.connect_with(domain, stream, |_| ())
78    }
79
80    pub fn connect_with<IO, F>(&self, domain: rustls::ServerName, stream: IO, f: F) -> Connect<IO>
81    where
82        IO: AsyncRead + AsyncWrite + Unpin,
83        F: FnOnce(&mut ClientConnection),
84    {
85        let mut session = match ClientConnection::new(self.inner.clone(), domain) {
86            Ok(session) => session,
87            Err(error) => {
88                return Connect(MidHandshake::Error {
89                    io: stream,
90                    // TODO(eliza): should this really return an `io::Error`?
91                    // Probably not...
92                    error: io::Error::new(io::ErrorKind::Other, error),
93                });
94            }
95        };
96        f(&mut session);
97
98        Connect(MidHandshake::Handshaking(client::TlsStream {
99            io: stream,
100
101            #[cfg(not(feature = "early-data"))]
102            state: TlsState::Stream,
103
104            #[cfg(feature = "early-data")]
105            state: if self.early_data && session.early_data().is_some() {
106                TlsState::EarlyData(0, Vec::new())
107            } else {
108                TlsState::Stream
109            },
110
111            #[cfg(feature = "early-data")]
112            early_waker: None,
113
114            session,
115        }))
116    }
117}
118
119impl TlsAcceptor {
120    #[inline]
121    pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
122    where
123        IO: AsyncRead + AsyncWrite + Unpin,
124    {
125        self.accept_with(stream, |_| ())
126    }
127
128    pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
129    where
130        IO: AsyncRead + AsyncWrite + Unpin,
131        F: FnOnce(&mut ServerConnection),
132    {
133        let mut session = match ServerConnection::new(self.inner.clone()) {
134            Ok(session) => session,
135            Err(error) => {
136                return Accept(MidHandshake::Error {
137                    io: stream,
138                    // TODO(eliza): should this really return an `io::Error`?
139                    // Probably not...
140                    error: io::Error::new(io::ErrorKind::Other, error),
141                });
142            }
143        };
144        f(&mut session);
145
146        Accept(MidHandshake::Handshaking(server::TlsStream {
147            session,
148            io: stream,
149            state: TlsState::Stream,
150        }))
151    }
152}
153
154pub struct LazyConfigAcceptor<IO> {
155    acceptor: rustls::server::Acceptor,
156    io: Option<IO>,
157}
158
159impl<IO> LazyConfigAcceptor<IO>
160where
161    IO: AsyncRead + AsyncWrite + Unpin,
162{
163    #[inline]
164    pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
165        Self {
166            acceptor,
167            io: Some(io),
168        }
169    }
170}
171
172impl<IO> Future for LazyConfigAcceptor<IO>
173where
174    IO: AsyncRead + AsyncWrite + Unpin,
175{
176    type Output = Result<StartHandshake<IO>, io::Error>;
177
178    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179        let this = self.get_mut();
180        loop {
181            let io = match this.io.as_mut() {
182                Some(io) => io,
183                None => {
184                    return Poll::Ready(Err(io::Error::new(
185                        io::ErrorKind::Other,
186                        "acceptor cannot be polled after acceptance",
187                    )))
188                }
189            };
190
191            let mut reader = common::SyncReadAdapter { io, cx };
192            match this.acceptor.read_tls(&mut reader) {
193                Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
194                Ok(_) => {}
195                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
196                Err(e) => return Err(e).into(),
197            }
198
199            match this.acceptor.accept() {
200                Ok(Some(accepted)) => {
201                    let io = this.io.take().unwrap();
202                    return Poll::Ready(Ok(StartHandshake { accepted, io }));
203                }
204                Ok(None) => continue,
205                Err(err) => {
206                    return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)))
207                }
208            }
209        }
210    }
211}
212
213pub struct StartHandshake<IO> {
214    accepted: rustls::server::Accepted,
215    io: IO,
216}
217
218impl<IO> StartHandshake<IO>
219where
220    IO: AsyncRead + AsyncWrite + Unpin,
221{
222    pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
223        self.accepted.client_hello()
224    }
225
226    pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
227        self.into_stream_with(config, |_| ())
228    }
229
230    pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
231    where
232        F: FnOnce(&mut ServerConnection),
233    {
234        let mut conn = match self.accepted.into_connection(config) {
235            Ok(conn) => conn,
236            Err(error) => {
237                return Accept(MidHandshake::Error {
238                    io: self.io,
239                    // TODO(eliza): should this really return an `io::Error`?
240                    // Probably not...
241                    error: io::Error::new(io::ErrorKind::Other, error),
242                });
243            }
244        };
245        f(&mut conn);
246
247        Accept(MidHandshake::Handshaking(server::TlsStream {
248            session: conn,
249            io: self.io,
250            state: TlsState::Stream,
251        }))
252    }
253}
254
255/// Future returned from `TlsConnector::connect` which will resolve
256/// once the connection handshake has finished.
257pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
258
259/// Future returned from `TlsAcceptor::accept` which will resolve
260/// once the accept handshake has finished.
261pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
262
263/// Like [Connect], but returns `IO` on failure.
264pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);
265
266/// Like [Accept], but returns `IO` on failure.
267pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
268
269impl<IO> Connect<IO> {
270    #[inline]
271    pub fn into_fallible(self) -> FallibleConnect<IO> {
272        FallibleConnect(self.0)
273    }
274}
275
276impl<IO> Accept<IO> {
277    #[inline]
278    pub fn into_fallible(self) -> FallibleAccept<IO> {
279        FallibleAccept(self.0)
280    }
281}
282
283impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
284    type Output = io::Result<client::TlsStream<IO>>;
285
286    #[inline]
287    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
288        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
289    }
290}
291
292impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
293    type Output = io::Result<server::TlsStream<IO>>;
294
295    #[inline]
296    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
297        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
298    }
299}
300
301impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
302    type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
303
304    #[inline]
305    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
306        Pin::new(&mut self.0).poll(cx)
307    }
308}
309
310impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
311    type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
312
313    #[inline]
314    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
315        Pin::new(&mut self.0).poll(cx)
316    }
317}
318
319/// Unified TLS stream type
320///
321/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use
322/// a single type to keep both client- and server-initiated TLS-encrypted connections.
323#[derive(Debug)]
324pub enum TlsStream<T> {
325    Client(client::TlsStream<T>),
326    Server(server::TlsStream<T>),
327}
328
329impl<T> TlsStream<T> {
330    pub fn get_ref(&self) -> (&T, &CommonState) {
331        use TlsStream::*;
332        match self {
333            Client(io) => {
334                let (io, session) = io.get_ref();
335                (io, &*session)
336            }
337            Server(io) => {
338                let (io, session) = io.get_ref();
339                (io, &*session)
340            }
341        }
342    }
343
344    pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
345        use TlsStream::*;
346        match self {
347            Client(io) => {
348                let (io, session) = io.get_mut();
349                (io, &mut *session)
350            }
351            Server(io) => {
352                let (io, session) = io.get_mut();
353                (io, &mut *session)
354            }
355        }
356    }
357}
358
359impl<T> From<client::TlsStream<T>> for TlsStream<T> {
360    fn from(s: client::TlsStream<T>) -> Self {
361        Self::Client(s)
362    }
363}
364
365impl<T> From<server::TlsStream<T>> for TlsStream<T> {
366    fn from(s: server::TlsStream<T>) -> Self {
367        Self::Server(s)
368    }
369}
370
371#[cfg(unix)]
372impl<S> AsRawFd for TlsStream<S>
373where
374    S: AsRawFd,
375{
376    fn as_raw_fd(&self) -> RawFd {
377        self.get_ref().0.as_raw_fd()
378    }
379}
380
381#[cfg(windows)]
382impl<S> AsRawSocket for TlsStream<S>
383where
384    S: AsRawSocket,
385{
386    fn as_raw_socket(&self) -> RawSocket {
387        self.get_ref().0.as_raw_socket()
388    }
389}
390
391impl<T> AsyncRead for TlsStream<T>
392where
393    T: AsyncRead + AsyncWrite + Unpin,
394{
395    #[inline]
396    fn poll_read(
397        self: Pin<&mut Self>,
398        cx: &mut Context<'_>,
399        buf: &mut [u8],
400    ) -> Poll<io::Result<usize>> {
401        match self.get_mut() {
402            TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
403            TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
404        }
405    }
406}
407
408impl<T> AsyncWrite for TlsStream<T>
409where
410    T: AsyncRead + AsyncWrite + Unpin,
411{
412    #[inline]
413    fn poll_write(
414        self: Pin<&mut Self>,
415        cx: &mut Context<'_>,
416        buf: &[u8],
417    ) -> Poll<io::Result<usize>> {
418        match self.get_mut() {
419            TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
420            TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
421        }
422    }
423
424    #[inline]
425    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
426        match self.get_mut() {
427            TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
428            TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
429        }
430    }
431
432    #[inline]
433    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
434        match self.get_mut() {
435            TlsStream::Client(x) => Pin::new(x).poll_close(cx),
436            TlsStream::Server(x) => Pin::new(x).poll_close(cx),
437        }
438    }
439}