1use std::future::Future;
40use std::io;
41#[cfg(unix)]
42use std::os::unix::io::{AsRawFd, RawFd};
43#[cfg(windows)]
44use std::os::windows::io::{AsRawSocket, RawSocket};
45use std::pin::Pin;
46use std::sync::Arc;
47use std::task::{Context, Poll};
48
49pub use rustls;
50use rustls::server::AcceptedAlert;
51use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
52use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
53
/// Unwraps a `Poll::Ready(value)` into `value`, or makes the enclosing function
/// return `Poll::Pending` immediately.
///
/// Declared ahead of the `mod` declarations so that, through `macro_rules!`
/// textual scoping, the submodules are able to use it as well.
macro_rules! ready {
    ( $poll:expr ) => {
        match $poll {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
62
63pub mod client;
64mod common;
65use common::{MidHandshake, TlsState};
66pub mod server;
67
/// A wrapper around a shared `rustls::ClientConfig`, providing async
/// `connect`/`connect_with` methods that start a client-side TLS handshake.
#[derive(Clone)]
pub struct TlsConnector {
    // Shared client configuration; cloned (refcount bump) for every new connection.
    inner: Arc<ClientConfig>,
    // Whether new connections should start in the early-data state when the
    // session reports early data is available. `From<Arc<ClientConfig>>` sets
    // it to `false`; `TlsConnector::early_data` changes it.
    #[cfg(feature = "early-data")]
    early_data: bool,
}
75
/// A wrapper around a shared `rustls::ServerConfig`, providing async
/// `accept`/`accept_with` methods that start a server-side TLS handshake.
#[derive(Clone)]
pub struct TlsAcceptor {
    // Shared server configuration; cloned (refcount bump) for every accepted connection.
    inner: Arc<ServerConfig>,
}
81
82impl From<Arc<ClientConfig>> for TlsConnector {
83 fn from(inner: Arc<ClientConfig>) -> TlsConnector {
84 TlsConnector {
85 inner,
86 #[cfg(feature = "early-data")]
87 early_data: false,
88 }
89 }
90}
91
92impl From<Arc<ServerConfig>> for TlsAcceptor {
93 fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
94 TlsAcceptor { inner }
95 }
96}
97
98impl TlsConnector {
99 #[cfg(feature = "early-data")]
104 pub fn early_data(mut self, flag: bool) -> TlsConnector {
105 self.early_data = flag;
106 self
107 }
108
109 #[inline]
110 pub fn connect<IO>(&self, domain: pki_types::ServerName<'static>, stream: IO) -> Connect<IO>
111 where
112 IO: AsyncRead + AsyncWrite + Unpin,
113 {
114 self.connect_with(domain, stream, |_| ())
115 }
116
117 pub fn connect_with<IO, F>(
118 &self,
119 domain: pki_types::ServerName<'static>,
120 stream: IO,
121 f: F,
122 ) -> Connect<IO>
123 where
124 IO: AsyncRead + AsyncWrite + Unpin,
125 F: FnOnce(&mut ClientConnection),
126 {
127 let mut session = match ClientConnection::new(self.inner.clone(), domain) {
128 Ok(session) => session,
129 Err(error) => {
130 return Connect(MidHandshake::Error {
131 io: stream,
132 error: io::Error::new(io::ErrorKind::Other, error),
135 });
136 }
137 };
138 f(&mut session);
139
140 Connect(MidHandshake::Handshaking(client::TlsStream {
141 io: stream,
142
143 #[cfg(not(feature = "early-data"))]
144 state: TlsState::Stream,
145
146 #[cfg(feature = "early-data")]
147 state: if self.early_data && session.early_data().is_some() {
148 TlsState::EarlyData(0, Vec::new())
149 } else {
150 TlsState::Stream
151 },
152
153 #[cfg(feature = "early-data")]
154 early_waker: None,
155
156 session,
157 }))
158 }
159}
160
161impl TlsAcceptor {
162 #[inline]
163 pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
164 where
165 IO: AsyncRead + AsyncWrite + Unpin,
166 {
167 self.accept_with(stream, |_| ())
168 }
169
170 pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
171 where
172 IO: AsyncRead + AsyncWrite + Unpin,
173 F: FnOnce(&mut ServerConnection),
174 {
175 let mut session = match ServerConnection::new(self.inner.clone()) {
176 Ok(session) => session,
177 Err(error) => {
178 return Accept(MidHandshake::Error {
179 io: stream,
180 error: io::Error::new(io::ErrorKind::Other, error),
183 });
184 }
185 };
186 f(&mut session);
187
188 Accept(MidHandshake::Handshaking(server::TlsStream {
189 session,
190 io: stream,
191 state: TlsState::Stream,
192 }))
193 }
194}
195
196pub struct LazyConfigAcceptor<IO> {
197 acceptor: rustls::server::Acceptor,
198 io: Option<IO>,
199 alert: Option<(rustls::Error, AcceptedAlert)>,
200}
201
202impl<IO> LazyConfigAcceptor<IO>
203where
204 IO: AsyncRead + AsyncWrite + Unpin,
205{
206 #[inline]
207 pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
208 Self {
209 acceptor,
210 io: Some(io),
211 alert: None,
212 }
213 }
214
215 pub fn take_io(&mut self) -> Option<IO> {
257 self.io.take()
258 }
259}
260
impl<IO> Future for LazyConfigAcceptor<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    type Output = Result<StartHandshake<IO>, io::Error>;

    /// Drives the acceptor. Each loop iteration performs these steps in order:
    /// 1. Finish writing any pending alert.
    /// 2. Read more TLS bytes from the IO.
    /// 3. Ask the acceptor whether the client hello is complete.
    ///
    /// Resolves to `Ok(StartHandshake)` on acceptance. It resolves to an
    /// `io::Error` in these cases:
    /// - EOF (`UnexpectedEof`);
    /// - a read error;
    /// - an acceptance failure (`InvalidData`, after its alert is written or fails to write);
    /// - being polled once the IO has been taken.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        loop {
            // The IO is `None` once acceptance succeeded (it was moved into
            // `StartHandshake` below) or after `take_io`.
            let io = match this.io.as_mut() {
                Some(io) => io,
                None => {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::Other,
                        "acceptor cannot be polled after acceptance",
                    )))
                }
            };

            // A previous `accept()` failed and left an alert for the peer.
            // Keep writing it across iterations; `Ok(0)` presumably means the
            // alert is fully written. Once done (or if writing fails), report
            // the original rustls error as `InvalidData`. Any write error itself
            // is discarded.
            if let Some((err, mut alert)) = this.alert.take() {
                match alert.write(&mut common::SyncWriteAdapter { io, cx }) {
                    // Presumably the adapter's translation of a pending write:
                    // put the alert back and wait to be woken.
                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                        this.alert = Some((err, alert));
                        return Poll::Pending;
                    }
                    Ok(0) | Err(_) => {
                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err)))
                    }
                    Ok(_) => {
                        this.alert = Some((err, alert));
                        continue;
                    }
                };
            }

            // Feed more bytes to the acceptor; EOF before acceptance is an error.
            let mut reader = common::SyncReadAdapter { io, cx };
            match this.acceptor.read_tls(&mut reader) {
                Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
                Ok(_) => {}
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
                Err(e) => return Err(e).into(),
            }

            match this.acceptor.accept() {
                // The IO is known to be `Some` here: it was matched at the top
                // of this iteration and nothing in between takes it.
                Ok(Some(accepted)) => {
                    let io = this.io.take().unwrap();
                    return Poll::Ready(Ok(StartHandshake { accepted, io }));
                }
                // Presumably not enough data yet: loop and read more.
                Ok(None) => {}
                // Stash the alert; the next iteration starts writing it.
                Err((err, alert)) => {
                    this.alert = Some((err, alert));
                }
            }
        }
    }
}
317
/// The result of a [`LazyConfigAcceptor`]: an accepted client hello plus the
/// IO it was read from. It waits for a `ServerConfig` (via `into_stream`) to
/// continue the handshake.
pub struct StartHandshake<IO> {
    accepted: rustls::server::Accepted,
    io: IO,
}
322
323impl<IO> StartHandshake<IO>
324where
325 IO: AsyncRead + AsyncWrite + Unpin,
326{
327 pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
328 self.accepted.client_hello()
329 }
330
331 pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
332 self.into_stream_with(config, |_| ())
333 }
334
335 pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
336 where
337 F: FnOnce(&mut ServerConnection),
338 {
339 let mut conn = match self.accepted.into_connection(config) {
340 Ok(conn) => conn,
341 Err((error, alert)) => {
342 return Accept(MidHandshake::SendAlert {
343 io: self.io,
344 alert,
345 error: io::Error::new(io::ErrorKind::InvalidData, error),
348 });
349 }
350 };
351 f(&mut conn);
352
353 Accept(MidHandshake::Handshaking(server::TlsStream {
354 session: conn,
355 io: self.io,
356 state: TlsState::Stream,
357 }))
358 }
359}
360
361pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
364
365pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
368
369pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);
371
372pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
374
375impl<IO> Connect<IO> {
376 #[inline]
377 pub fn into_fallible(self) -> FallibleConnect<IO> {
378 FallibleConnect(self.0)
379 }
380
381 pub fn get_ref(&self) -> Option<&IO> {
382 match &self.0 {
383 MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
384 MidHandshake::SendAlert { io, .. } => Some(io),
385 MidHandshake::Error { io, .. } => Some(io),
386 MidHandshake::End => None,
387 }
388 }
389
390 pub fn get_mut(&mut self) -> Option<&mut IO> {
391 match &mut self.0 {
392 MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
393 MidHandshake::SendAlert { io, .. } => Some(io),
394 MidHandshake::Error { io, .. } => Some(io),
395 MidHandshake::End => None,
396 }
397 }
398}
399
400impl<IO> Accept<IO> {
401 #[inline]
402 pub fn into_fallible(self) -> FallibleAccept<IO> {
403 FallibleAccept(self.0)
404 }
405
406 pub fn get_ref(&self) -> Option<&IO> {
407 match &self.0 {
408 MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
409 MidHandshake::SendAlert { io, .. } => Some(io),
410 MidHandshake::Error { io, .. } => Some(io),
411 MidHandshake::End => None,
412 }
413 }
414
415 pub fn get_mut(&mut self) -> Option<&mut IO> {
416 match &mut self.0 {
417 MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
418 MidHandshake::SendAlert { io, .. } => Some(io),
419 MidHandshake::Error { io, .. } => Some(io),
420 MidHandshake::End => None,
421 }
422 }
423}
424
425impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
426 type Output = io::Result<client::TlsStream<IO>>;
427
428 #[inline]
429 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
430 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
431 }
432}
433
434impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
435 type Output = io::Result<server::TlsStream<IO>>;
436
437 #[inline]
438 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
439 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
440 }
441}
442
443impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
444 type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
445
446 #[inline]
447 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
448 Pin::new(&mut self.0).poll(cx)
449 }
450}
451
452impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
453 type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
454
455 #[inline]
456 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
457 Pin::new(&mut self.0).poll(cx)
458 }
459}
460
/// A TLS stream that is either the client side or the server side of a connection.
///
/// Built from either side through the `From` impls.
// NOTE(review): `large_enum_variant` is silenced without a stated reason.
// Presumably this avoids boxing one side on every stream; confirm and record why.
#[allow(clippy::large_enum_variant)] #[derive(Debug)]
pub enum TlsStream<T> {
    /// The client side, as produced by a completed [`Connect`].
    Client(client::TlsStream<T>),
    /// The server side, as produced by a completed [`Accept`].
    Server(server::TlsStream<T>),
}
471
472impl<T> TlsStream<T> {
473 pub fn get_ref(&self) -> (&T, &CommonState) {
474 use TlsStream::*;
475 match self {
476 Client(io) => {
477 let (io, session) = io.get_ref();
478 (io, session)
479 }
480 Server(io) => {
481 let (io, session) = io.get_ref();
482 (io, session)
483 }
484 }
485 }
486
487 pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
488 use TlsStream::*;
489 match self {
490 Client(io) => {
491 let (io, session) = io.get_mut();
492 (io, &mut *session)
493 }
494 Server(io) => {
495 let (io, session) = io.get_mut();
496 (io, &mut *session)
497 }
498 }
499 }
500}
501
502impl<T> From<client::TlsStream<T>> for TlsStream<T> {
503 fn from(s: client::TlsStream<T>) -> Self {
504 Self::Client(s)
505 }
506}
507
508impl<T> From<server::TlsStream<T>> for TlsStream<T> {
509 fn from(s: server::TlsStream<T>) -> Self {
510 Self::Server(s)
511 }
512}
513
514#[cfg(unix)]
515impl<S> AsRawFd for TlsStream<S>
516where
517 S: AsRawFd,
518{
519 fn as_raw_fd(&self) -> RawFd {
520 self.get_ref().0.as_raw_fd()
521 }
522}
523
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket of the underlying IO.
    fn as_raw_socket(&self) -> RawSocket {
        let (io, _common) = self.get_ref();
        io.as_raw_socket()
    }
}
533
534impl<T> AsyncRead for TlsStream<T>
535where
536 T: AsyncRead + AsyncWrite + Unpin,
537{
538 #[inline]
539 fn poll_read(
540 self: Pin<&mut Self>,
541 cx: &mut Context<'_>,
542 buf: &mut ReadBuf<'_>,
543 ) -> Poll<io::Result<()>> {
544 match self.get_mut() {
545 TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
546 TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
547 }
548 }
549}
550
551impl<T> AsyncWrite for TlsStream<T>
552where
553 T: AsyncRead + AsyncWrite + Unpin,
554{
555 #[inline]
556 fn poll_write(
557 self: Pin<&mut Self>,
558 cx: &mut Context<'_>,
559 buf: &[u8],
560 ) -> Poll<io::Result<usize>> {
561 match self.get_mut() {
562 TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
563 TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
564 }
565 }
566
567 #[inline]
568 fn poll_write_vectored(
569 self: Pin<&mut Self>,
570 cx: &mut Context<'_>,
571 bufs: &[io::IoSlice<'_>],
572 ) -> Poll<io::Result<usize>> {
573 match self.get_mut() {
574 TlsStream::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs),
575 TlsStream::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs),
576 }
577 }
578
579 #[inline]
580 fn is_write_vectored(&self) -> bool {
581 match self {
582 TlsStream::Client(x) => x.is_write_vectored(),
583 TlsStream::Server(x) => x.is_write_vectored(),
584 }
585 }
586
587 #[inline]
588 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
589 match self.get_mut() {
590 TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
591 TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
592 }
593 }
594
595 #[inline]
596 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
597 match self.get_mut() {
598 TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx),
599 TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx),
600 }
601 }
602}