libp2p_websocket/
tls.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use futures_rustls::{rustls, TlsAcceptor, TlsConnector};
22use std::convert::TryFrom;
23use std::{fmt, io, sync::Arc};
24
25/// TLS configuration.
26#[derive(Clone)]
27pub struct Config {
28    pub(crate) client: TlsConnector,
29    pub(crate) server: Option<TlsAcceptor>,
30}
31
32impl fmt::Debug for Config {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        f.write_str("Config")
35    }
36}
37
38/// Private key, DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
39#[derive(Clone)]
40pub struct PrivateKey(rustls::PrivateKey);
41
42impl PrivateKey {
43    /// Assert the given bytes are DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
44    pub fn new(bytes: Vec<u8>) -> Self {
45        PrivateKey(rustls::PrivateKey(bytes))
46    }
47}
48
49/// Certificate, DER-encoded X.509 format.
50#[derive(Debug, Clone)]
51pub struct Certificate(rustls::Certificate);
52
53impl Certificate {
54    /// Assert the given bytes are in DER-encoded X.509 format.
55    pub fn new(bytes: Vec<u8>) -> Self {
56        Certificate(rustls::Certificate(bytes))
57    }
58}
59
60impl Config {
61    /// Create a new TLS configuration with the given server key and certificate chain.
62    pub fn new<I>(key: PrivateKey, certs: I) -> Result<Self, Error>
63    where
64        I: IntoIterator<Item = Certificate>,
65    {
66        let mut builder = Config::builder();
67        builder.server(key, certs)?;
68        Ok(builder.finish())
69    }
70
71    /// Create a client-only configuration.
72    pub fn client() -> Self {
73        let client = rustls::ClientConfig::builder()
74            .with_safe_defaults()
75            .with_root_certificates(client_root_store())
76            .with_no_client_auth();
77        Config {
78            client: Arc::new(client).into(),
79            server: None,
80        }
81    }
82
83    /// Create a new TLS configuration builder.
84    pub fn builder() -> Builder {
85        Builder {
86            client_root_store: client_root_store(),
87            server: None,
88        }
89    }
90}
91
92/// Setup the rustls client configuration.
93fn client_root_store() -> rustls::RootCertStore {
94    let mut client_root_store = rustls::RootCertStore::empty();
95    client_root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
96        rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
97            ta.subject,
98            ta.spki,
99            ta.name_constraints,
100        )
101    }));
102    client_root_store
103}
104
105/// TLS configuration builder.
106pub struct Builder {
107    client_root_store: rustls::RootCertStore,
108    server: Option<rustls::ServerConfig>,
109}
110
111impl Builder {
112    /// Set server key and certificate chain.
113    pub fn server<I>(&mut self, key: PrivateKey, certs: I) -> Result<&mut Self, Error>
114    where
115        I: IntoIterator<Item = Certificate>,
116    {
117        let certs = certs.into_iter().map(|c| c.0).collect();
118        let server = rustls::ServerConfig::builder()
119            .with_safe_defaults()
120            .with_no_client_auth()
121            .with_single_cert(certs, key.0)
122            .map_err(|e| Error::Tls(Box::new(e)))?;
123        self.server = Some(server);
124        Ok(self)
125    }
126
127    /// Add an additional trust anchor.
128    pub fn add_trust(&mut self, cert: &Certificate) -> Result<&mut Self, Error> {
129        self.client_root_store
130            .add(&cert.0)
131            .map_err(|e| Error::Tls(Box::new(e)))?;
132        Ok(self)
133    }
134
135    /// Finish configuration.
136    pub fn finish(self) -> Config {
137        let client = rustls::ClientConfig::builder()
138            .with_safe_defaults()
139            .with_root_certificates(self.client_root_store)
140            .with_no_client_auth();
141
142        Config {
143            client: Arc::new(client).into(),
144            server: self.server.map(|s| Arc::new(s).into()),
145        }
146    }
147}
148
149pub(crate) fn dns_name_ref(name: &str) -> Result<rustls::ServerName, Error> {
150    rustls::ServerName::try_from(name).map_err(|_| Error::InvalidDnsName(name.into()))
151}
152
153// Error //////////////////////////////////////////////////////////////////////////////////////////
154
/// Errors raised while building or using the TLS configuration.
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
    /// An underlying I/O error.
    Io(io::Error),
    /// An error reported by the TLS implementation.
    Tls(Box<dyn std::error::Error + Send + Sync>),
    /// The input could not be parsed as a DNS name; carries the rejected string.
    InvalidDnsName(String),
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Io(inner) => write!(f, "i/o error: {inner}"),
            Self::Tls(inner) => write!(f, "tls error: {inner}"),
            Self::InvalidDnsName(name) => write!(f, "invalid DNS name: {name}"),
        }
    }
}

impl std::error::Error for Error {
    // The wrapped I/O and TLS errors are exposed as the source; a bad DNS name
    // has no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::InvalidDnsName(_) => None,
            Self::Io(inner) => Some(inner),
            Self::Tls(inner) => Some(&**inner),
        }
    }
}

impl From<io::Error> for Error {
    fn from(inner: io::Error) -> Self {
        Self::Io(inner)
    }
}
189        Error::Io(e)
190    }
191}