/// Evaluates a `Poll` expression: yields the inner value when it is `Ready`,
/// and returns `Poll::Pending` from the enclosing function otherwise.
macro_rules! ready {
    ($poll:expr) => {
        match $poll {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
11
12pub mod client;
13mod common;
14pub mod server;
15
16use common::{MidHandshake, Stream, TlsState};
17use futures_io::{AsyncRead, AsyncWrite};
18use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
19use std::future::Future;
20use std::io;
21#[cfg(unix)]
22use std::os::unix::io::{AsRawFd, RawFd};
23#[cfg(windows)]
24use std::os::windows::io::{AsRawSocket, RawSocket};
25use std::pin::Pin;
26use std::sync::Arc;
27use std::task::{Context, Poll};
28
29pub use rustls;
30
/// Client-side entry point: wraps a shared [`ClientConfig`] and starts TLS
/// handshakes over caller-supplied streams via `connect` / `connect_with`.
#[derive(Clone)]
pub struct TlsConnector {
    /// Shared client configuration; the `Arc` is cloned for every new connection.
    inner: Arc<ClientConfig>,
    /// When `true`, `connect_with` starts the stream in the early-data state
    /// if the new session reports early data as available.
    #[cfg(feature = "early-data")]
    early_data: bool,
}
38
/// Server-side entry point: wraps a shared [`ServerConfig`] and starts TLS
/// handshakes over caller-supplied streams via `accept` / `accept_with`.
#[derive(Clone)]
pub struct TlsAcceptor {
    /// Shared server configuration; the `Arc` is cloned for every new connection.
    inner: Arc<ServerConfig>,
}
44
45impl From<Arc<ClientConfig>> for TlsConnector {
46 fn from(inner: Arc<ClientConfig>) -> TlsConnector {
47 TlsConnector {
48 inner,
49 #[cfg(feature = "early-data")]
50 early_data: false,
51 }
52 }
53}
54
55impl From<Arc<ServerConfig>> for TlsAcceptor {
56 fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
57 TlsAcceptor { inner }
58 }
59}
60
61impl TlsConnector {
62 #[cfg(feature = "early-data")]
67 pub fn early_data(mut self, flag: bool) -> TlsConnector {
68 self.early_data = flag;
69 self
70 }
71
72 #[inline]
73 pub fn connect<IO>(&self, domain: rustls::ServerName, stream: IO) -> Connect<IO>
74 where
75 IO: AsyncRead + AsyncWrite + Unpin,
76 {
77 self.connect_with(domain, stream, |_| ())
78 }
79
80 pub fn connect_with<IO, F>(&self, domain: rustls::ServerName, stream: IO, f: F) -> Connect<IO>
81 where
82 IO: AsyncRead + AsyncWrite + Unpin,
83 F: FnOnce(&mut ClientConnection),
84 {
85 let mut session = match ClientConnection::new(self.inner.clone(), domain) {
86 Ok(session) => session,
87 Err(error) => {
88 return Connect(MidHandshake::Error {
89 io: stream,
90 error: io::Error::new(io::ErrorKind::Other, error),
93 });
94 }
95 };
96 f(&mut session);
97
98 Connect(MidHandshake::Handshaking(client::TlsStream {
99 io: stream,
100
101 #[cfg(not(feature = "early-data"))]
102 state: TlsState::Stream,
103
104 #[cfg(feature = "early-data")]
105 state: if self.early_data && session.early_data().is_some() {
106 TlsState::EarlyData(0, Vec::new())
107 } else {
108 TlsState::Stream
109 },
110
111 #[cfg(feature = "early-data")]
112 early_waker: None,
113
114 session,
115 }))
116 }
117}
118
119impl TlsAcceptor {
120 #[inline]
121 pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
122 where
123 IO: AsyncRead + AsyncWrite + Unpin,
124 {
125 self.accept_with(stream, |_| ())
126 }
127
128 pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
129 where
130 IO: AsyncRead + AsyncWrite + Unpin,
131 F: FnOnce(&mut ServerConnection),
132 {
133 let mut session = match ServerConnection::new(self.inner.clone()) {
134 Ok(session) => session,
135 Err(error) => {
136 return Accept(MidHandshake::Error {
137 io: stream,
138 error: io::Error::new(io::ErrorKind::Other, error),
141 });
142 }
143 };
144 f(&mut session);
145
146 Accept(MidHandshake::Handshaking(server::TlsStream {
147 session,
148 io: stream,
149 state: TlsState::Stream,
150 }))
151 }
152}
153
/// Future that feeds bytes from `io` into a `rustls::server::Acceptor` and
/// resolves to a [`StartHandshake`] once the acceptor yields an accepted
/// connection, so the server configuration can be chosen afterwards.
pub struct LazyConfigAcceptor<IO> {
    /// rustls acceptor that buffers and parses the incoming TLS bytes.
    acceptor: rustls::server::Acceptor,
    /// Underlying transport; becomes `None` once it has been handed over to
    /// the resulting `StartHandshake`.
    io: Option<IO>,
}
158
159impl<IO> LazyConfigAcceptor<IO>
160where
161 IO: AsyncRead + AsyncWrite + Unpin,
162{
163 #[inline]
164 pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
165 Self {
166 acceptor,
167 io: Some(io),
168 }
169 }
170}
171
172impl<IO> Future for LazyConfigAcceptor<IO>
173where
174 IO: AsyncRead + AsyncWrite + Unpin,
175{
176 type Output = Result<StartHandshake<IO>, io::Error>;
177
178 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179 let this = self.get_mut();
180 loop {
181 let io = match this.io.as_mut() {
182 Some(io) => io,
183 None => {
184 return Poll::Ready(Err(io::Error::new(
185 io::ErrorKind::Other,
186 "acceptor cannot be polled after acceptance",
187 )))
188 }
189 };
190
191 let mut reader = common::SyncReadAdapter { io, cx };
192 match this.acceptor.read_tls(&mut reader) {
193 Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
194 Ok(_) => {}
195 Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
196 Err(e) => return Err(e).into(),
197 }
198
199 match this.acceptor.accept() {
200 Ok(Some(accepted)) => {
201 let io = this.io.take().unwrap();
202 return Poll::Ready(Ok(StartHandshake { accepted, io }));
203 }
204 Ok(None) => continue,
205 Err(err) => {
206 return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)))
207 }
208 }
209 }
210 }
211}
212
/// Result of a completed [`LazyConfigAcceptor`]: an accepted connection whose
/// `ClientHello` can be inspected before a server configuration is supplied.
pub struct StartHandshake<IO> {
    /// Accepted connection produced by the rustls acceptor.
    accepted: rustls::server::Accepted,
    /// Transport the handshake will continue on.
    io: IO,
}
217
218impl<IO> StartHandshake<IO>
219where
220 IO: AsyncRead + AsyncWrite + Unpin,
221{
222 pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
223 self.accepted.client_hello()
224 }
225
226 pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
227 self.into_stream_with(config, |_| ())
228 }
229
230 pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
231 where
232 F: FnOnce(&mut ServerConnection),
233 {
234 let mut conn = match self.accepted.into_connection(config) {
235 Ok(conn) => conn,
236 Err(error) => {
237 return Accept(MidHandshake::Error {
238 io: self.io,
239 error: io::Error::new(io::ErrorKind::Other, error),
242 });
243 }
244 };
245 f(&mut conn);
246
247 Accept(MidHandshake::Handshaking(server::TlsStream {
248 session: conn,
249 io: self.io,
250 state: TlsState::Stream,
251 }))
252 }
253}
254
/// Future for an in-progress client handshake; resolves to a
/// `client::TlsStream` or an `io::Error`.
pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);

/// Future for an in-progress server handshake; resolves to a
/// `server::TlsStream` or an `io::Error`.
pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);

/// Like [`Connect`], but on failure resolves to the error together with the
/// underlying `IO`.
pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);

/// Like [`Accept`], but on failure resolves to the error together with the
/// underlying `IO`.
pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
268
269impl<IO> Connect<IO> {
270 #[inline]
271 pub fn into_fallible(self) -> FallibleConnect<IO> {
272 FallibleConnect(self.0)
273 }
274}
275
276impl<IO> Accept<IO> {
277 #[inline]
278 pub fn into_fallible(self) -> FallibleAccept<IO> {
279 FallibleAccept(self.0)
280 }
281}
282
283impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
284 type Output = io::Result<client::TlsStream<IO>>;
285
286 #[inline]
287 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
288 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
289 }
290}
291
292impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
293 type Output = io::Result<server::TlsStream<IO>>;
294
295 #[inline]
296 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
297 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
298 }
299}
300
301impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
302 type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
303
304 #[inline]
305 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
306 Pin::new(&mut self.0).poll(cx)
307 }
308}
309
310impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
311 type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
312
313 #[inline]
314 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
315 Pin::new(&mut self.0).poll(cx)
316 }
317}
318
/// Unified stream type that holds either a client-side or a server-side TLS
/// stream, so both can be handled through a single type.
#[derive(Debug)]
pub enum TlsStream<T> {
    /// A stream produced by the client side of a handshake.
    Client(client::TlsStream<T>),
    /// A stream produced by the server side of a handshake.
    Server(server::TlsStream<T>),
}
328
329impl<T> TlsStream<T> {
330 pub fn get_ref(&self) -> (&T, &CommonState) {
331 use TlsStream::*;
332 match self {
333 Client(io) => {
334 let (io, session) = io.get_ref();
335 (io, &*session)
336 }
337 Server(io) => {
338 let (io, session) = io.get_ref();
339 (io, &*session)
340 }
341 }
342 }
343
344 pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
345 use TlsStream::*;
346 match self {
347 Client(io) => {
348 let (io, session) = io.get_mut();
349 (io, &mut *session)
350 }
351 Server(io) => {
352 let (io, session) = io.get_mut();
353 (io, &mut *session)
354 }
355 }
356 }
357}
358
359impl<T> From<client::TlsStream<T>> for TlsStream<T> {
360 fn from(s: client::TlsStream<T>) -> Self {
361 Self::Client(s)
362 }
363}
364
365impl<T> From<server::TlsStream<T>> for TlsStream<T> {
366 fn from(s: server::TlsStream<T>) -> Self {
367 Self::Server(s)
368 }
369}
370
371#[cfg(unix)]
372impl<S> AsRawFd for TlsStream<S>
373where
374 S: AsRawFd,
375{
376 fn as_raw_fd(&self) -> RawFd {
377 self.get_ref().0.as_raw_fd()
378 }
379}
380
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket of the underlying IO object.
    fn as_raw_socket(&self) -> RawSocket {
        let (io, _) = self.get_ref();
        io.as_raw_socket()
    }
}
390
391impl<T> AsyncRead for TlsStream<T>
392where
393 T: AsyncRead + AsyncWrite + Unpin,
394{
395 #[inline]
396 fn poll_read(
397 self: Pin<&mut Self>,
398 cx: &mut Context<'_>,
399 buf: &mut [u8],
400 ) -> Poll<io::Result<usize>> {
401 match self.get_mut() {
402 TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
403 TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
404 }
405 }
406}
407
408impl<T> AsyncWrite for TlsStream<T>
409where
410 T: AsyncRead + AsyncWrite + Unpin,
411{
412 #[inline]
413 fn poll_write(
414 self: Pin<&mut Self>,
415 cx: &mut Context<'_>,
416 buf: &[u8],
417 ) -> Poll<io::Result<usize>> {
418 match self.get_mut() {
419 TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
420 TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
421 }
422 }
423
424 #[inline]
425 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
426 match self.get_mut() {
427 TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
428 TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
429 }
430 }
431
432 #[inline]
433 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
434 match self.get_mut() {
435 TlsStream::Client(x) => Pin::new(x).poll_close(cx),
436 TlsStream::Server(x) => Pin::new(x).poll_close(cx),
437 }
438 }
439}