snow/builder.rs

1#[cfg(feature = "hfs")]
2use crate::params::HandshakeModifier;
3use crate::{
4    cipherstate::{CipherState, CipherStates},
5    constants::{MAXDHLEN, PSKLEN},
6    error::{Error, InitStage, Prerequisite},
7    handshakestate::HandshakeState,
8    params::NoiseParams,
9    resolvers::{BoxedCryptoResolver, CryptoResolver},
10    utils::Toggle,
11};
12use subtle::ConstantTimeEq;
13
/// A keypair object returned by [`Builder::generate_keypair()`].
///
/// Both halves are the raw key bytes exactly as exposed by the resolved DH
/// implementation (`privkey()` / `pubkey()`). Equality (`==`) compares both
/// halves with `subtle::ConstantTimeEq` instead of the ordinary slice `==`.
pub struct Keypair {
    /// The private asymmetric key
    pub private: Vec<u8>,
    /// The public asymmetric key
    pub public:  Vec<u8>,
}
23
24impl PartialEq for Keypair {
25    fn eq(&self, other: &Keypair) -> bool {
26        let priv_eq = self.private.ct_eq(&other.private);
27        let pub_eq = self.public.ct_eq(&other.public);
28
29        (priv_eq & pub_eq).into()
30    }
31}
32
33/// Generates a [`HandshakeState`] and also validates that all the prerequisites for
34/// the given parameters are satisfied.
35///
36/// # Examples
37///
38/// ```
39/// # use snow::Builder;
40/// # let my_long_term_key = [0u8; 32];
41/// # let their_pub_key = [0u8; 32];
42/// # #[cfg(any(feature = "default-resolver", feature = "ring-accelerated"))]
43/// let noise = Builder::new("Noise_XX_25519_ChaChaPoly_BLAKE2s".parse().unwrap())
44///     .local_private_key(&my_long_term_key)
45///     .remote_public_key(&their_pub_key)
46///     .prologue("noise is just swell".as_bytes())
47///     .build_initiator()
48///     .unwrap();
49/// ```
50pub struct Builder<'builder> {
51    params:   NoiseParams,
52    resolver: BoxedCryptoResolver,
53    s:        Option<&'builder [u8]>,
54    e_fixed:  Option<&'builder [u8]>,
55    rs:       Option<&'builder [u8]>,
56    psks:     [Option<&'builder [u8]>; 10],
57    plog:     Option<&'builder [u8]>,
58}
59
60impl<'builder> Builder<'builder> {
61    /// Create a Builder with the default crypto resolver.
62    #[cfg(all(
63        feature = "default-resolver",
64        not(any(feature = "ring-accelerated", feature = "libsodium-accelerated"))
65    ))]
66    pub fn new(params: NoiseParams) -> Self {
67        use crate::resolvers::DefaultResolver;
68
69        Self::with_resolver(params, Box::new(DefaultResolver::default()))
70    }
71
72    /// Create a Builder with the ring resolver and default resolver as a fallback.
73    #[cfg(all(not(feature = "libsodium-accelerated"), feature = "ring-accelerated"))]
74    pub fn new(params: NoiseParams) -> Self {
75        use crate::resolvers::{DefaultResolver, FallbackResolver, RingResolver};
76
77        Self::with_resolver(
78            params,
79            Box::new(FallbackResolver::new(Box::new(RingResolver), Box::new(DefaultResolver))),
80        )
81    }
82
83    /// Create a Builder with the ring resolver and default resolver as a fallback.
84    #[cfg(all(not(feature = "ring-accelerated"), feature = "libsodium-accelerated"))]
85    pub fn new(params: NoiseParams) -> Self {
86        use crate::resolvers::{DefaultResolver, FallbackResolver, SodiumResolver};
87
88        Self::with_resolver(
89            params,
90            Box::new(FallbackResolver::new(Box::new(SodiumResolver), Box::new(DefaultResolver))),
91        )
92    }
93
94    /// Create a Builder with a custom crypto resolver.
95    pub fn with_resolver(params: NoiseParams, resolver: BoxedCryptoResolver) -> Self {
96        Builder { params, resolver, s: None, e_fixed: None, rs: None, plog: None, psks: [None; 10] }
97    }
98
99    /// Specify a PSK (only used with `NoisePSK` base parameter)
100    ///
101    /// # Safety
102    /// This will overwrite the value provided in any previous call to this method. Please take care
103    /// to ensure this is not a security risk. In future versions, multiple calls to the same
104    /// builder method will be explicitly prohibited.
105    pub fn psk(mut self, location: u8, key: &'builder [u8]) -> Self {
106        self.psks[location as usize] = Some(key);
107        self
108    }
109
110    /// Your static private key (can be generated with [`generate_keypair()`]).
111    ///
112    /// # Safety
113    /// This will overwrite the value provided in any previous call to this method. Please take care
114    /// to ensure this is not a security risk. In future versions, multiple calls to the same
115    /// builder method will be explicitly prohibited.
116    ///
117    /// [`generate_keypair()`]: #method.generate_keypair
118    pub fn local_private_key(mut self, key: &'builder [u8]) -> Self {
119        self.s = Some(key);
120        self
121    }
122
123    #[doc(hidden)]
124    pub fn fixed_ephemeral_key_for_testing_only(mut self, key: &'builder [u8]) -> Self {
125        self.e_fixed = Some(key);
126        self
127    }
128
129    /// Arbitrary data to be hashed in to the handshake hash value.
130    ///
131    /// # Safety
132    /// This will overwrite the value provided in any previous call to this method. Please take care
133    /// to ensure this is not a security risk. In future versions, multiple calls to the same
134    /// builder method will be explicitly prohibited.
135    pub fn prologue(mut self, key: &'builder [u8]) -> Self {
136        self.plog = Some(key);
137        self
138    }
139
140    /// The responder's static public key.
141    ///
142    /// # Safety
143    /// This will overwrite the value provided in any previous call to this method. Please take care
144    /// to ensure this is not a security risk. In future versions, multiple calls to the same
145    /// builder method will be explicitly prohibited.
146    pub fn remote_public_key(mut self, pub_key: &'builder [u8]) -> Self {
147        self.rs = Some(pub_key);
148        self
149    }
150
151    // TODO: performance issue w/ creating a new RNG and DH instance per call.
152    /// Generate a new asymmetric keypair (for use as a static key).
153    pub fn generate_keypair(&self) -> Result<Keypair, Error> {
154        let mut rng = self.resolver.resolve_rng().ok_or(InitStage::GetRngImpl)?;
155        let mut dh = self.resolver.resolve_dh(&self.params.dh).ok_or(InitStage::GetDhImpl)?;
156        let mut private = vec![0u8; dh.priv_len()];
157        let mut public = vec![0u8; dh.pub_len()];
158        dh.generate(&mut *rng);
159
160        private.copy_from_slice(dh.privkey());
161        public.copy_from_slice(dh.pubkey());
162
163        Ok(Keypair { private, public })
164    }
165
166    /// Build a [`HandshakeState`] for the side who will initiate the handshake (send the first message)
167    pub fn build_initiator(self) -> Result<HandshakeState, Error> {
168        self.build(true)
169    }
170
171    /// Build a [`HandshakeState`] for the side who will be responder (receive the first message)
172    pub fn build_responder(self) -> Result<HandshakeState, Error> {
173        self.build(false)
174    }
175
176    fn build(self, initiator: bool) -> Result<HandshakeState, Error> {
177        if self.s.is_none() && self.params.handshake.pattern.needs_local_static_key(initiator) {
178            return Err(Prerequisite::LocalPrivateKey.into());
179        }
180
181        if self.rs.is_none() && self.params.handshake.pattern.need_known_remote_pubkey(initiator) {
182            return Err(Prerequisite::RemotePublicKey.into());
183        }
184
185        let rng = self.resolver.resolve_rng().ok_or(InitStage::GetRngImpl)?;
186        let cipher =
187            self.resolver.resolve_cipher(&self.params.cipher).ok_or(InitStage::GetCipherImpl)?;
188        let hash = self.resolver.resolve_hash(&self.params.hash).ok_or(InitStage::GetHashImpl)?;
189        let mut s_dh = self.resolver.resolve_dh(&self.params.dh).ok_or(InitStage::GetDhImpl)?;
190        let mut e_dh = self.resolver.resolve_dh(&self.params.dh).ok_or(InitStage::GetDhImpl)?;
191        let cipher1 =
192            self.resolver.resolve_cipher(&self.params.cipher).ok_or(InitStage::GetCipherImpl)?;
193        let cipher2 =
194            self.resolver.resolve_cipher(&self.params.cipher).ok_or(InitStage::GetCipherImpl)?;
195        let handshake_cipherstate = CipherState::new(cipher);
196        let cipherstates = CipherStates::new(CipherState::new(cipher1), CipherState::new(cipher2))?;
197
198        let s = match self.s {
199            Some(k) => {
200                (*s_dh).set(k);
201                Toggle::on(s_dh)
202            },
203            None => Toggle::off(s_dh),
204        };
205
206        if let Some(fixed_k) = self.e_fixed {
207            (*e_dh).set(fixed_k);
208        }
209        let e = Toggle::off(e_dh);
210
211        let mut rs_buf = [0u8; MAXDHLEN];
212        let rs = match self.rs {
213            Some(v) => {
214                rs_buf[..v.len()].copy_from_slice(v);
215                Toggle::on(rs_buf)
216            },
217            None => Toggle::off(rs_buf),
218        };
219
220        let re = Toggle::off([0u8; MAXDHLEN]);
221
222        let mut psks = [None::<[u8; PSKLEN]>; 10];
223        for (i, psk) in self.psks.iter().enumerate() {
224            if let Some(key) = *psk {
225                if key.len() != PSKLEN {
226                    return Err(InitStage::ValidatePskLengths.into());
227                }
228                let mut k = [0u8; PSKLEN];
229                k.copy_from_slice(key);
230                psks[i] = Some(k);
231            }
232        }
233
234        let mut hs = HandshakeState::new(
235            rng,
236            handshake_cipherstate,
237            hash,
238            s,
239            e,
240            self.e_fixed.is_some(),
241            rs,
242            re,
243            initiator,
244            self.params,
245            psks,
246            self.plog.unwrap_or(&[]),
247            cipherstates,
248        )?;
249        Self::resolve_kem(self.resolver, &mut hs)?;
250        Ok(hs)
251    }
252
253    #[cfg(not(feature = "hfs"))]
254    fn resolve_kem(_: Box<dyn CryptoResolver>, _: &mut HandshakeState) -> Result<(), Error> {
255        // HFS is disabled, return nothing
256        Ok(())
257    }
258
259    #[cfg(feature = "hfs")]
260    fn resolve_kem(
261        resolver: Box<dyn CryptoResolver>,
262        hs: &mut HandshakeState,
263    ) -> Result<(), Error> {
264        if hs.params.handshake.modifiers.list.contains(&HandshakeModifier::Hfs) {
265            if let Some(kem_choice) = hs.params.kem {
266                let kem = resolver.resolve_kem(&kem_choice).ok_or(InitStage::GetKemImpl)?;
267                hs.set_kem(kem);
268            } else {
269                return Err(InitStage::GetKemImpl.into());
270            }
271        }
272        Ok(())
273    }
274}
275
#[cfg(test)]
#[cfg(any(feature = "default-resolver", feature = "ring-accelerated"))]
mod tests {
    use super::*;

    #[test]
    fn test_builder() {
        let _noise = Builder::new("Noise_NN_25519_ChaChaPoly_SHA256".parse().unwrap())
            .prologue(&[2, 2, 2, 2, 2, 2, 2, 2])
            .local_private_key(&[0u8; 32])
            .build_initiator()
            .unwrap();
    }

    #[test]
    fn test_builder_keygen() {
        let builder = Builder::new("Noise_NN_25519_ChaChaPoly_SHA256".parse().unwrap());
        let key1 = builder.generate_keypair().unwrap();
        let key2 = builder.generate_keypair().unwrap();
        // Two freshly generated keypairs must differ.
        assert!(key1 != key2);
    }

    #[test]
    fn test_builder_bad_spec() {
        let params = "Noise_NK_25519_ChaChaPoly_BLAH256".parse::<NoiseParams>();
        assert!(params.is_err(), "NoiseParams should have failed");
    }

    #[test]
    fn test_builder_missing_prereqs() {
        // NK requires the initiator to know the responder's static key; it is omitted here.
        let noise = Builder::new("Noise_NK_25519_ChaChaPoly_SHA256".parse().unwrap())
            .prologue(&[2, 2, 2, 2, 2, 2, 2, 2])
            .local_private_key(&[0u8; 32])
            .build_initiator();

        assert!(noise.is_err(), "builder should have failed on build");
    }

    #[test]
    fn test_partialeq_impl() {
        let keypair_1 = Keypair { private: vec![0x01; 32], public: vec![0x01; 32] };
        let mut keypair_2 = Keypair { private: vec![0x01; 32], public: vec![0x01; 32] };

        // If both private and public are the same, they are equal.
        assert!(keypair_1 == keypair_2);

        // If either public or private differs, they are not equal.

        // Wrong private
        keypair_2.private = vec![0x50; 32];
        assert!(keypair_1 != keypair_2);
        // Reset to original
        keypair_2.private = vec![0x01; 32];
        // Wrong public
        keypair_2.public = vec![0x50; 32];
        assert!(keypair_1 != keypair_2);
    }
}