tokio_tungstenite/
tls.rs

1//! Connection helper.
2use tokio::io::{AsyncRead, AsyncWrite};
3
4use tungstenite::{
5    client::uri_mode, error::Error, handshake::client::Response, protocol::WebSocketConfig,
6};
7
8use crate::{client_async_with_config, IntoClientRequest, WebSocketStream};
9
10pub use crate::stream::MaybeTlsStream;
11
/// Chooses how the underlying stream is (or is not) wrapped in TLS when a
/// connection is established.
///
/// * `Plain` never encrypts. Requests whose URL requires TLS are rejected with
///   an error when this variant is used.
/// * `NativeTls` (feature `native-tls`) supplies a preconfigured `native-tls`
///   connector.
/// * `Rustls` (internal feature `__rustls-tls`) supplies a shared rustls client
///   configuration.
#[derive(Clone)]
#[non_exhaustive]
pub enum Connector {
    /// No TLS; the stream is used exactly as given.
    Plain,
    /// TLS through a caller-provided `native-tls` connector.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls_crate::TlsConnector),
    /// TLS through a caller-provided, reference-counted rustls client configuration.
    #[cfg(feature = "__rustls-tls")]
    Rustls(::std::sync::Arc<rustls::ClientConfig>),
}
27
28mod encryption {
29    #[cfg(feature = "native-tls")]
30    pub mod native_tls {
31        use native_tls_crate::TlsConnector;
32        use tokio_native_tls::TlsConnector as TokioTlsConnector;
33
34        use tokio::io::{AsyncRead, AsyncWrite};
35
36        use tungstenite::{error::TlsError, stream::Mode, Error};
37
38        use crate::stream::MaybeTlsStream;
39
40        pub async fn wrap_stream<S>(
41            socket: S,
42            domain: String,
43            mode: Mode,
44            tls_connector: Option<TlsConnector>,
45        ) -> Result<MaybeTlsStream<S>, Error>
46        where
47            S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
48        {
49            match mode {
50                Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
51                Mode::Tls => {
52                    let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok);
53                    let connector = try_connector.map_err(TlsError::Native)?;
54                    let stream = TokioTlsConnector::from(connector);
55                    let connected = stream.connect(&domain, socket).await;
56                    match connected {
57                        Err(e) => Err(Error::Tls(e.into())),
58                        Ok(s) => Ok(MaybeTlsStream::NativeTls(s)),
59                    }
60                }
61            }
62        }
63    }
64
65    #[cfg(feature = "__rustls-tls")]
66    pub mod rustls {
67        pub use rustls::ClientConfig;
68        use rustls::{RootCertStore, ServerName};
69        use tokio_rustls::TlsConnector as TokioTlsConnector;
70
71        use std::{convert::TryFrom, sync::Arc};
72        use tokio::io::{AsyncRead, AsyncWrite};
73
74        use tungstenite::{error::TlsError, stream::Mode, Error};
75
76        use crate::stream::MaybeTlsStream;
77
78        pub async fn wrap_stream<S>(
79            socket: S,
80            domain: String,
81            mode: Mode,
82            tls_connector: Option<Arc<ClientConfig>>,
83        ) -> Result<MaybeTlsStream<S>, Error>
84        where
85            S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
86        {
87            match mode {
88                Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
89                Mode::Tls => {
90                    let config = match tls_connector {
91                        Some(config) => config,
92                        None => {
93                            #[allow(unused_mut)]
94                            let mut root_store = RootCertStore::empty();
95                            #[cfg(feature = "rustls-tls-native-roots")]
96                            {
97                                let native_certs = rustls_native_certs::load_native_certs()?;
98                                let der_certs: Vec<Vec<u8>> =
99                                    native_certs.into_iter().map(|cert| cert.0).collect();
100                                let total_number = der_certs.len();
101                                let (number_added, number_ignored) =
102                                    root_store.add_parsable_certificates(&der_certs);
103                                log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
104                            }
105                            #[cfg(feature = "rustls-tls-webpki-roots")]
106                            {
107                                root_store.add_trust_anchors(
108                                    webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
109                                        rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
110                                            ta.subject,
111                                            ta.spki,
112                                            ta.name_constraints,
113                                        )
114                                    })
115                                );
116                            }
117
118                            Arc::new(
119                                ClientConfig::builder()
120                                    .with_safe_defaults()
121                                    .with_root_certificates(root_store)
122                                    .with_no_client_auth(),
123                            )
124                        }
125                    };
126                    let domain = ServerName::try_from(domain.as_str())
127                        .map_err(|_| TlsError::InvalidDnsName)?;
128                    let stream = TokioTlsConnector::from(config);
129                    let connected = stream.connect(domain, socket).await;
130
131                    match connected {
132                        Err(e) => Err(Error::Io(e)),
133                        Ok(s) => Ok(MaybeTlsStream::Rustls(s)),
134                    }
135                }
136            }
137        }
138    }
139
140    pub mod plain {
141        use tokio::io::{AsyncRead, AsyncWrite};
142
143        use tungstenite::{
144            error::{Error, UrlError},
145            stream::Mode,
146        };
147
148        use crate::stream::MaybeTlsStream;
149
150        pub async fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>, Error>
151        where
152            S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
153        {
154            match mode {
155                Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
156                Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
157            }
158        }
159    }
160}
161
/// Runs the WebSocket client handshake for `request` over `stream`, wrapping
/// `stream` in TLS first when the request URL calls for it.
///
/// This is shorthand for `client_async_tls_with_config()` with no WebSocket
/// configuration and no explicit [`Connector`]. In that case a default TLS
/// connector is created on demand.
///
/// # Errors
///
/// Propagates any error from building the request, validating its URL, the TLS
/// handshake, or the WebSocket handshake.
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<MaybeTlsStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
    MaybeTlsStream<S>: Unpin,
{
    let websocket_config: Option<WebSocketConfig> = None;
    let tls_connector: Option<Connector> = None;
    client_async_tls_with_config(request, stream, websocket_config, tls_connector).await
}
176
177/// The same as `client_async_tls()` but the one can specify a websocket configuration,
178/// and an optional connector. If no connector is specified, a default one will
179/// be created.
180///
181/// Please refer to `client_async_tls()` for more details.
182pub async fn client_async_tls_with_config<R, S>(
183    request: R,
184    stream: S,
185    config: Option<WebSocketConfig>,
186    connector: Option<Connector>,
187) -> Result<(WebSocketStream<MaybeTlsStream<S>>, Response), Error>
188where
189    R: IntoClientRequest + Unpin,
190    S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
191    MaybeTlsStream<S>: Unpin,
192{
193    let request = request.into_client_request()?;
194
195    #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
196    let domain = crate::domain(&request)?;
197
198    // Make sure we check domain and mode first. URL must be valid.
199    let mode = uri_mode(request.uri())?;
200
201    let stream = match connector {
202        Some(conn) => match conn {
203            #[cfg(feature = "native-tls")]
204            Connector::NativeTls(conn) => {
205                self::encryption::native_tls::wrap_stream(stream, domain, mode, Some(conn)).await
206            }
207            #[cfg(feature = "__rustls-tls")]
208            Connector::Rustls(conn) => {
209                self::encryption::rustls::wrap_stream(stream, domain, mode, Some(conn)).await
210            }
211            Connector::Plain => self::encryption::plain::wrap_stream(stream, mode).await,
212        },
213        None => {
214            #[cfg(feature = "native-tls")]
215            {
216                self::encryption::native_tls::wrap_stream(stream, domain, mode, None).await
217            }
218            #[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
219            {
220                self::encryption::rustls::wrap_stream(stream, domain, mode, None).await
221            }
222            #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
223            {
224                self::encryption::plain::wrap_stream(stream, mode).await
225            }
226        }
227    }?;
228
229    client_async_with_config(request, stream, config).await
230}