1#[cfg(feature = "hfs")]
2use crate::params::HandshakeModifier;
3use crate::{
4 cipherstate::{CipherState, CipherStates},
5 constants::{MAXDHLEN, PSKLEN},
6 error::{Error, InitStage, Prerequisite},
7 handshakestate::HandshakeState,
8 params::NoiseParams,
9 resolvers::{BoxedCryptoResolver, CryptoResolver},
10 utils::Toggle,
11};
12use subtle::ConstantTimeEq;
13
14pub struct Keypair {
18 pub private: Vec<u8>,
20 pub public: Vec<u8>,
22}
23
24impl PartialEq for Keypair {
25 fn eq(&self, other: &Keypair) -> bool {
26 let priv_eq = self.private.ct_eq(&other.private);
27 let pub_eq = self.public.ct_eq(&other.public);
28
29 (priv_eq & pub_eq).into()
30 }
31}
32
33pub struct Builder<'builder> {
51 params: NoiseParams,
52 resolver: BoxedCryptoResolver,
53 s: Option<&'builder [u8]>,
54 e_fixed: Option<&'builder [u8]>,
55 rs: Option<&'builder [u8]>,
56 psks: [Option<&'builder [u8]>; 10],
57 plog: Option<&'builder [u8]>,
58}
59
60impl<'builder> Builder<'builder> {
61 #[cfg(all(
63 feature = "default-resolver",
64 not(any(feature = "ring-accelerated", feature = "libsodium-accelerated"))
65 ))]
66 pub fn new(params: NoiseParams) -> Self {
67 use crate::resolvers::DefaultResolver;
68
69 Self::with_resolver(params, Box::new(DefaultResolver::default()))
70 }
71
72 #[cfg(all(not(feature = "libsodium-accelerated"), feature = "ring-accelerated"))]
74 pub fn new(params: NoiseParams) -> Self {
75 use crate::resolvers::{DefaultResolver, FallbackResolver, RingResolver};
76
77 Self::with_resolver(
78 params,
79 Box::new(FallbackResolver::new(Box::new(RingResolver), Box::new(DefaultResolver))),
80 )
81 }
82
83 #[cfg(all(not(feature = "ring-accelerated"), feature = "libsodium-accelerated"))]
85 pub fn new(params: NoiseParams) -> Self {
86 use crate::resolvers::{DefaultResolver, FallbackResolver, SodiumResolver};
87
88 Self::with_resolver(
89 params,
90 Box::new(FallbackResolver::new(Box::new(SodiumResolver), Box::new(DefaultResolver))),
91 )
92 }
93
94 pub fn with_resolver(params: NoiseParams, resolver: BoxedCryptoResolver) -> Self {
96 Builder { params, resolver, s: None, e_fixed: None, rs: None, plog: None, psks: [None; 10] }
97 }
98
99 pub fn psk(mut self, location: u8, key: &'builder [u8]) -> Self {
106 self.psks[location as usize] = Some(key);
107 self
108 }
109
110 pub fn local_private_key(mut self, key: &'builder [u8]) -> Self {
119 self.s = Some(key);
120 self
121 }
122
123 #[doc(hidden)]
124 pub fn fixed_ephemeral_key_for_testing_only(mut self, key: &'builder [u8]) -> Self {
125 self.e_fixed = Some(key);
126 self
127 }
128
129 pub fn prologue(mut self, key: &'builder [u8]) -> Self {
136 self.plog = Some(key);
137 self
138 }
139
140 pub fn remote_public_key(mut self, pub_key: &'builder [u8]) -> Self {
147 self.rs = Some(pub_key);
148 self
149 }
150
151 pub fn generate_keypair(&self) -> Result<Keypair, Error> {
154 let mut rng = self.resolver.resolve_rng().ok_or(InitStage::GetRngImpl)?;
155 let mut dh = self.resolver.resolve_dh(&self.params.dh).ok_or(InitStage::GetDhImpl)?;
156 let mut private = vec![0u8; dh.priv_len()];
157 let mut public = vec![0u8; dh.pub_len()];
158 dh.generate(&mut *rng);
159
160 private.copy_from_slice(dh.privkey());
161 public.copy_from_slice(dh.pubkey());
162
163 Ok(Keypair { private, public })
164 }
165
166 pub fn build_initiator(self) -> Result<HandshakeState, Error> {
168 self.build(true)
169 }
170
171 pub fn build_responder(self) -> Result<HandshakeState, Error> {
173 self.build(false)
174 }
175
176 fn build(self, initiator: bool) -> Result<HandshakeState, Error> {
177 if self.s.is_none() && self.params.handshake.pattern.needs_local_static_key(initiator) {
178 return Err(Prerequisite::LocalPrivateKey.into());
179 }
180
181 if self.rs.is_none() && self.params.handshake.pattern.need_known_remote_pubkey(initiator) {
182 return Err(Prerequisite::RemotePublicKey.into());
183 }
184
185 let rng = self.resolver.resolve_rng().ok_or(InitStage::GetRngImpl)?;
186 let cipher =
187 self.resolver.resolve_cipher(&self.params.cipher).ok_or(InitStage::GetCipherImpl)?;
188 let hash = self.resolver.resolve_hash(&self.params.hash).ok_or(InitStage::GetHashImpl)?;
189 let mut s_dh = self.resolver.resolve_dh(&self.params.dh).ok_or(InitStage::GetDhImpl)?;
190 let mut e_dh = self.resolver.resolve_dh(&self.params.dh).ok_or(InitStage::GetDhImpl)?;
191 let cipher1 =
192 self.resolver.resolve_cipher(&self.params.cipher).ok_or(InitStage::GetCipherImpl)?;
193 let cipher2 =
194 self.resolver.resolve_cipher(&self.params.cipher).ok_or(InitStage::GetCipherImpl)?;
195 let handshake_cipherstate = CipherState::new(cipher);
196 let cipherstates = CipherStates::new(CipherState::new(cipher1), CipherState::new(cipher2))?;
197
198 let s = match self.s {
199 Some(k) => {
200 (*s_dh).set(k);
201 Toggle::on(s_dh)
202 },
203 None => Toggle::off(s_dh),
204 };
205
206 if let Some(fixed_k) = self.e_fixed {
207 (*e_dh).set(fixed_k);
208 }
209 let e = Toggle::off(e_dh);
210
211 let mut rs_buf = [0u8; MAXDHLEN];
212 let rs = match self.rs {
213 Some(v) => {
214 rs_buf[..v.len()].copy_from_slice(v);
215 Toggle::on(rs_buf)
216 },
217 None => Toggle::off(rs_buf),
218 };
219
220 let re = Toggle::off([0u8; MAXDHLEN]);
221
222 let mut psks = [None::<[u8; PSKLEN]>; 10];
223 for (i, psk) in self.psks.iter().enumerate() {
224 if let Some(key) = *psk {
225 if key.len() != PSKLEN {
226 return Err(InitStage::ValidatePskLengths.into());
227 }
228 let mut k = [0u8; PSKLEN];
229 k.copy_from_slice(key);
230 psks[i] = Some(k);
231 }
232 }
233
234 let mut hs = HandshakeState::new(
235 rng,
236 handshake_cipherstate,
237 hash,
238 s,
239 e,
240 self.e_fixed.is_some(),
241 rs,
242 re,
243 initiator,
244 self.params,
245 psks,
246 self.plog.unwrap_or(&[]),
247 cipherstates,
248 )?;
249 Self::resolve_kem(self.resolver, &mut hs)?;
250 Ok(hs)
251 }
252
253 #[cfg(not(feature = "hfs"))]
254 fn resolve_kem(_: Box<dyn CryptoResolver>, _: &mut HandshakeState) -> Result<(), Error> {
255 Ok(())
257 }
258
259 #[cfg(feature = "hfs")]
260 fn resolve_kem(
261 resolver: Box<dyn CryptoResolver>,
262 hs: &mut HandshakeState,
263 ) -> Result<(), Error> {
264 if hs.params.handshake.modifiers.list.contains(&HandshakeModifier::Hfs) {
265 if let Some(kem_choice) = hs.params.kem {
266 let kem = resolver.resolve_kem(&kem_choice).ok_or(InitStage::GetKemImpl)?;
267 hs.set_kem(kem);
268 } else {
269 return Err(InitStage::GetKemImpl.into());
270 }
271 }
272 Ok(())
273 }
274}
275
#[cfg(test)]
#[cfg(any(feature = "default-resolver", feature = "ring-accelerated"))]
mod tests {
    use super::*;

    /// An NN initiator with a prologue and a static key builds successfully.
    #[test]
    fn test_builder() {
        let params: NoiseParams = "Noise_NN_25519_ChaChaPoly_SHA256".parse().unwrap();
        let prologue = [2u8; 8];
        let static_key = [0u8; 32];
        let _noise = Builder::new(params)
            .prologue(&prologue)
            .local_private_key(&static_key)
            .build_initiator()
            .unwrap();
    }

    /// Two keypairs generated from the same builder are distinct.
    #[test]
    fn test_builder_keygen() {
        let builder = Builder::new("Noise_NN_25519_ChaChaPoly_SHA256".parse().unwrap());
        let first = builder.generate_keypair();
        let second = builder.generate_keypair();
        assert!(first.unwrap() != second.unwrap());
    }

    /// An unknown hash name makes parameter parsing fail.
    #[test]
    fn test_builder_bad_spec() {
        let parsed = "Noise_NK_25519_ChaChaPoly_BLAH256".parse::<NoiseParams>();
        assert!(parsed.is_err(), "NoiseParams should have failed");
    }

    /// An NK initiator without the remote static key cannot be built.
    #[test]
    fn test_builder_missing_prereqs() {
        let outcome = Builder::new("Noise_NK_25519_ChaChaPoly_SHA256".parse().unwrap())
            .prologue(&[2, 2, 2, 2, 2, 2, 2, 2])
            .local_private_key(&[0u8; 32])
            .build_initiator();
        assert!(outcome.is_err(), "builder should have failed on build");
    }

    /// Keypairs compare equal only when both halves match.
    #[test]
    fn test_partialeq_impl() {
        let reference = Keypair { private: vec![0x01; 32], public: vec![0x01; 32] };
        let mut candidate = Keypair { private: vec![0x01; 32], public: vec![0x01; 32] };
        assert!(reference == candidate);

        candidate.private = vec![0x50; 32];
        assert!(reference != candidate);

        candidate.private = vec![0x01; 32];
        candidate.public = vec![0x50; 32];
        assert!(reference != candidate);
    }
}
340}