1#[cfg(feature = "risky-raw-split")]
2use crate::constants::{CIPHERKEYLEN, MAXHASHLEN};
3#[cfg(feature = "hfs")]
4use crate::constants::{MAXKEMCTLEN, MAXKEMPUBLEN, MAXKEMSSLEN};
5#[cfg(feature = "hfs")]
6use crate::types::Kem;
7use crate::{
8 cipherstate::{CipherState, CipherStates},
9 constants::{MAXDHLEN, MAXMSGLEN, PSKLEN, TAGLEN},
10 error::{Error, InitStage, StateProblem},
11 params::{DhToken, HandshakeTokens, MessagePatterns, NoiseParams, Token},
12 stateless_transportstate::StatelessTransportState,
13 symmetricstate::SymmetricState,
14 transportstate::TransportState,
15 types::{Dh, Hash, Random},
16 utils::Toggle,
17};
18use std::{
19 convert::{TryFrom, TryInto},
20 fmt,
21};
22
/// State machine driving a single Noise handshake.
///
/// Holds the local and remote key material, the symmetric state into which the transcript is
/// mixed, and the position within the handshake's message patterns. After the last handshake
/// message is written or read, the transport cipher pair is split into `cipherstates`.
pub struct HandshakeState {
    /// Randomness source used when generating the local ephemeral (and, with `hfs`, KEM) keys.
    pub(crate) rng: Box<dyn Random>,
    /// Symmetric state: receives `mix_key`/`mix_hash` calls and encrypts/decrypts handshake
    /// payloads and static keys.
    pub(crate) symmetricstate: SymmetricState,
    /// Transport cipher pair, written by `split` once the final message pattern is processed.
    pub(crate) cipherstates: CipherStates,
    /// Local static key pair (a `Toggle`, so it may be disabled).
    pub(crate) s: Toggle<Box<dyn Dh>>,
    /// Local ephemeral key pair; regenerated on each `e` token unless `fixed_ephemeral` is set.
    pub(crate) e: Toggle<Box<dyn Dh>>,
    /// When true, the supplied ephemeral key is reused instead of being freshly generated.
    pub(crate) fixed_ephemeral: bool,
    /// Remote static public key buffer; only the first `dh_len()` bytes are meaningful.
    pub(crate) rs: Toggle<[u8; MAXDHLEN]>,
    /// Remote ephemeral public key buffer; only the first `dh_len()` bytes are meaningful.
    pub(crate) re: Toggle<[u8; MAXDHLEN]>,
    /// Whether this side is the initiator (the initiator also takes the first turn).
    pub(crate) initiator: bool,
    /// Protocol parameters; the protocol `name` and `handshake` pattern are read from here.
    pub(crate) params: NoiseParams,
    /// Pre-shared key slots, indexed by the `n` of a `Token::Psk(n)` token.
    pub(crate) psks: [Option<[u8; PSKLEN]>; 10],
    /// KEM used by the `E1` / `Ekem1` tokens (feature `hfs`), set via `set_kem`.
    #[cfg(feature = "hfs")]
    pub(crate) kem: Option<Box<dyn Kem>>,
    /// Remote KEM public key, stored after reading an `E1` token.
    #[cfg(feature = "hfs")]
    pub(crate) kem_re: Option<[u8; MAXKEMPUBLEN]>,
    /// True when the next handshake operation must be a write, false when it must be a read.
    pub(crate) my_turn: bool,
    /// Token sequence for every handshake message, in order.
    pub(crate) message_patterns: MessagePatterns,
    /// Index into `message_patterns` of the next message to write or read.
    pub(crate) pattern_position: usize,
}
49
50impl HandshakeState {
51 #[allow(clippy::too_many_arguments)]
52 pub(crate) fn new(
53 rng: Box<dyn Random>,
54 cipherstate: CipherState,
55 hasher: Box<dyn Hash>,
56 s: Toggle<Box<dyn Dh>>,
57 e: Toggle<Box<dyn Dh>>,
58 fixed_ephemeral: bool,
59 rs: Toggle<[u8; MAXDHLEN]>,
60 re: Toggle<[u8; MAXDHLEN]>,
61 initiator: bool,
62 params: NoiseParams,
63 psks: [Option<[u8; PSKLEN]>; 10],
64 prologue: &[u8],
65 cipherstates: CipherStates,
66 ) -> Result<HandshakeState, Error> {
67 if (s.is_on() && e.is_on() && s.pub_len() != e.pub_len())
68 || (s.is_on() && rs.is_on() && s.pub_len() > rs.len())
69 || (s.is_on() && re.is_on() && s.pub_len() > re.len())
70 {
71 return Err(InitStage::ValidateKeyLengths.into());
72 }
73
74 let tokens = HandshakeTokens::try_from(¶ms.handshake)?;
75
76 let mut symmetricstate = SymmetricState::new(cipherstate, hasher);
77
78 symmetricstate.initialize(¶ms.name);
79 symmetricstate.mix_hash(prologue);
80
81 let dh_len = s.pub_len();
82 if initiator {
83 for token in tokens.premsg_pattern_i {
84 symmetricstate.mix_hash(
85 match *token {
86 Token::S => &s,
87 Token::E => &e,
88 _ => unreachable!(),
89 }
90 .get()
91 .ok_or(StateProblem::MissingKeyMaterial)?
92 .pubkey(),
93 );
94 }
95 for token in tokens.premsg_pattern_r {
96 symmetricstate.mix_hash(
97 &match *token {
98 Token::S => &rs,
99 Token::E => &re,
100 _ => unreachable!(),
101 }
102 .get()
103 .ok_or(StateProblem::MissingKeyMaterial)?[..dh_len],
104 );
105 }
106 } else {
107 for token in tokens.premsg_pattern_i {
108 symmetricstate.mix_hash(
109 &match *token {
110 Token::S => &rs,
111 Token::E => &re,
112 _ => unreachable!(),
113 }
114 .get()
115 .ok_or(StateProblem::MissingKeyMaterial)?[..dh_len],
116 );
117 }
118 for token in tokens.premsg_pattern_r {
119 symmetricstate.mix_hash(
120 match *token {
121 Token::S => &s,
122 Token::E => &e,
123 _ => unreachable!(),
124 }
125 .get()
126 .ok_or(StateProblem::MissingKeyMaterial)?
127 .pubkey(),
128 );
129 }
130 }
131
132 Ok(HandshakeState {
133 rng,
134 symmetricstate,
135 cipherstates,
136 s,
137 e,
138 fixed_ephemeral,
139 rs,
140 re,
141 initiator,
142 params,
143 psks,
144 #[cfg(feature = "hfs")]
145 kem: None,
146 #[cfg(feature = "hfs")]
147 kem_re: None,
148 my_turn: initiator,
149 message_patterns: tokens.msg_patterns,
150 pattern_position: 0,
151 })
152 }
153
154 pub(crate) fn dh_len(&self) -> usize {
155 self.s.pub_len()
156 }
157
158 #[cfg(feature = "hfs")]
159 pub(crate) fn set_kem(&mut self, kem: Box<dyn Kem>) {
160 self.kem = Some(kem);
161 }
162
163 fn dh(&self, token: &DhToken) -> Result<[u8; MAXDHLEN], Error> {
164 let mut dh_out = [0u8; MAXDHLEN];
165 let (dh, key) = match (token, self.is_initiator()) {
166 (DhToken::Ee, _) => (&self.e, &self.re),
167 (DhToken::Ss, _) => (&self.s, &self.rs),
168 (DhToken::Se, true) | (DhToken::Es, false) => (&self.s, &self.re),
169 (DhToken::Es, true) | (DhToken::Se, false) => (&self.e, &self.rs),
170 };
171 if !(dh.is_on() && key.is_on()) {
172 return Err(StateProblem::MissingKeyMaterial.into());
173 }
174 dh.dh(&**key, &mut dh_out)?;
175 Ok(dh_out)
176 }
177
178 pub fn was_write_payload_encrypted(&self) -> bool {
194 self.symmetricstate.has_key()
195 }
196
197 pub fn write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error> {
207 let checkpoint = self.symmetricstate.checkpoint();
208 match self._write_message(payload, message) {
209 Ok(res) => {
210 self.pattern_position += 1;
211 self.my_turn = false;
212 Ok(res)
213 },
214 Err(err) => {
215 self.symmetricstate.restore(checkpoint);
216 Err(err)
217 },
218 }
219 }
220
221 fn _write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error> {
222 if !self.my_turn {
223 return Err(StateProblem::NotTurnToWrite.into());
224 } else if self.pattern_position >= self.message_patterns.len() {
225 return Err(StateProblem::HandshakeAlreadyFinished.into());
226 }
227
228 let mut byte_index = 0;
229 for token in self.message_patterns[self.pattern_position].iter() {
230 match token {
231 Token::E => {
232 if byte_index + self.e.pub_len() > message.len() {
233 return Err(Error::Input);
234 }
235
236 if !self.fixed_ephemeral {
237 self.e.generate(&mut *self.rng);
238 }
239 let pubkey = self.e.pubkey();
240 message[byte_index..byte_index + pubkey.len()].copy_from_slice(pubkey);
241 byte_index += pubkey.len();
242 self.symmetricstate.mix_hash(pubkey);
243 if self.params.handshake.is_psk() {
244 self.symmetricstate.mix_key(pubkey);
245 }
246 self.e.enable();
247 },
248 Token::S => {
249 if !self.s.is_on() {
250 return Err(StateProblem::MissingKeyMaterial.into());
251 } else if byte_index + self.s.pub_len() > message.len() {
252 return Err(Error::Input);
253 }
254
255 byte_index += self
256 .symmetricstate
257 .encrypt_and_mix_hash(self.s.pubkey(), &mut message[byte_index..])?;
258 },
259 Token::Psk(n) => match self.psks[*n as usize] {
260 Some(psk) => {
261 self.symmetricstate.mix_key_and_hash(&psk);
262 },
263 None => {
264 return Err(StateProblem::MissingPsk.into());
265 },
266 },
267 Token::Dh(t) => {
268 let dh_out = self.dh(t)?;
269 self.symmetricstate.mix_key(&dh_out[..self.dh_len()]);
270 },
271 #[cfg(feature = "hfs")]
272 Token::E1 => {
273 let kem = self.kem.as_mut().ok_or(Error::Input)?;
274 if kem.pub_len() > message.len() {
275 return Err(Error::Input);
276 }
277
278 kem.generate(&mut *self.rng);
279 byte_index += self
280 .symmetricstate
281 .encrypt_and_mix_hash(kem.pubkey(), &mut message[byte_index..])?;
282 },
283 #[cfg(feature = "hfs")]
284 Token::Ekem1 => {
285 let kem = self.kem.as_mut().unwrap();
286 let mut kem_output_buf = [0; MAXKEMSSLEN];
287 let mut ciphertext_buf = [0; MAXKEMCTLEN];
288
289 if kem.ciphertext_len() > message.len() {
290 return Err(Error::Input);
291 }
292
293 let kem_output = &mut kem_output_buf[..kem.shared_secret_len()];
294 let ciphertext = &mut ciphertext_buf[..kem.ciphertext_len()];
295 let pubkey = &self.kem_re.as_ref().unwrap()[..kem.pub_len()];
296 if kem.encapsulate(pubkey, kem_output, ciphertext).is_err() {
297 return Err(Error::Kem);
298 }
299
300 byte_index += self.symmetricstate.encrypt_and_mix_hash(
301 &ciphertext[..kem.ciphertext_len()],
302 &mut message[byte_index..],
303 )?;
304 self.symmetricstate.mix_key(&kem_output[..kem.shared_secret_len()]);
305 },
306 }
307 }
308
309 if byte_index + payload.len() + TAGLEN > message.len() {
310 return Err(Error::Input);
311 }
312 byte_index +=
313 self.symmetricstate.encrypt_and_mix_hash(payload, &mut message[byte_index..])?;
314 if byte_index > MAXMSGLEN {
315 return Err(Error::Input);
316 }
317 if self.pattern_position == (self.message_patterns.len() - 1) {
318 self.symmetricstate.split(&mut self.cipherstates.0, &mut self.cipherstates.1);
319 }
320 Ok(byte_index)
321 }
322
323 pub fn read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, Error> {
334 let checkpoint = self.symmetricstate.checkpoint();
335 match self._read_message(message, payload) {
336 Ok(res) => {
337 self.pattern_position += 1;
338 self.my_turn = true;
339 Ok(res)
340 },
341 Err(err) => {
342 self.symmetricstate.restore(checkpoint);
343 Err(err)
344 },
345 }
346 }
347
348 fn _read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, Error> {
349 if message.len() > MAXMSGLEN {
350 return Err(Error::Input);
351 } else if self.my_turn {
352 return Err(StateProblem::NotTurnToRead.into());
353 } else if self.pattern_position >= self.message_patterns.len() {
354 return Err(StateProblem::HandshakeAlreadyFinished.into());
355 }
356 let last = self.pattern_position == (self.message_patterns.len() - 1);
357
358 let dh_len = self.dh_len();
359 let mut ptr = message;
360 for token in self.message_patterns[self.pattern_position].iter() {
361 match token {
362 Token::E => {
363 if ptr.len() < dh_len {
364 return Err(Error::Input);
365 }
366 self.re[..dh_len].copy_from_slice(&ptr[..dh_len]);
367 ptr = &ptr[dh_len..];
368 self.symmetricstate.mix_hash(&self.re[..dh_len]);
369 if self.params.handshake.is_psk() {
370 self.symmetricstate.mix_key(&self.re[..dh_len]);
371 }
372 self.re.enable();
373 },
374 Token::S => {
375 let data = if self.symmetricstate.has_key() {
376 if ptr.len() < dh_len + TAGLEN {
377 return Err(Error::Input);
378 }
379 let temp = &ptr[..dh_len + TAGLEN];
380 ptr = &ptr[dh_len + TAGLEN..];
381 temp
382 } else {
383 if ptr.len() < dh_len {
384 return Err(Error::Input);
385 }
386 let temp = &ptr[..dh_len];
387 ptr = &ptr[dh_len..];
388 temp
389 };
390 self.symmetricstate.decrypt_and_mix_hash(data, &mut self.rs[..dh_len])?;
391 self.rs.enable();
392 },
393 Token::Psk(n) => match self.psks[*n as usize] {
394 Some(psk) => {
395 self.symmetricstate.mix_key_and_hash(&psk);
396 },
397 None => {
398 return Err(StateProblem::MissingPsk.into());
399 },
400 },
401 Token::Dh(t) => {
402 let dh_out = self.dh(t)?;
403 self.symmetricstate.mix_key(&dh_out[..self.dh_len()]);
404 },
405 #[cfg(feature = "hfs")]
406 Token::E1 => {
407 let kem = self.kem.as_ref().ok_or(Error::Kem)?;
408 let read_len = if self.symmetricstate.has_key() {
409 kem.pub_len() + TAGLEN
410 } else {
411 kem.pub_len()
412 };
413 if ptr.len() < read_len {
414 return Err(Error::Input);
415 }
416 let mut kem_re = [0; MAXKEMPUBLEN];
417 self.symmetricstate
418 .decrypt_and_mix_hash(&ptr[..read_len], &mut kem_re[..kem.pub_len()])?;
419 self.kem_re = Some(kem_re);
420 ptr = &ptr[read_len..];
421 },
422 #[cfg(feature = "hfs")]
423 Token::Ekem1 => {
424 let kem = self.kem.as_ref().unwrap();
425 let read_len = if self.symmetricstate.has_key() {
426 kem.ciphertext_len() + TAGLEN
427 } else {
428 kem.ciphertext_len()
429 };
430 if ptr.len() < read_len {
431 return Err(Error::Input);
432 }
433 let mut ciphertext_buf = [0; MAXKEMCTLEN];
434 let ciphertext = &mut ciphertext_buf[..kem.ciphertext_len()];
435 self.symmetricstate.decrypt_and_mix_hash(&ptr[..read_len], ciphertext)?;
436 let mut kem_output_buf = [0; MAXKEMSSLEN];
437 let kem_output = &mut kem_output_buf[..kem.shared_secret_len()];
438 kem.decapsulate(ciphertext, kem_output).map_err(|_| Error::Kem)?;
439 self.symmetricstate.mix_key(&kem_output[..kem.shared_secret_len()]);
440 ptr = &ptr[read_len..];
441 },
442 }
443 }
444
445 self.symmetricstate.decrypt_and_mix_hash(ptr, payload)?;
446 if last {
447 self.symmetricstate.split(&mut self.cipherstates.0, &mut self.cipherstates.1);
448 }
449 let payload_len =
450 if self.symmetricstate.has_key() { ptr.len() - TAGLEN } else { ptr.len() };
451 Ok(payload_len)
452 }
453
454 pub fn set_psk(&mut self, location: usize, key: &[u8]) -> Result<(), Error> {
462 if key.len() != PSKLEN || self.psks.len() <= location {
463 return Err(Error::Input);
464 }
465
466 let mut new_psk = [0u8; PSKLEN];
467 new_psk.copy_from_slice(key);
468 self.psks[location as usize] = Some(new_psk);
469
470 Ok(())
471 }
472
473 pub fn get_remote_static(&self) -> Option<&[u8]> {
480 self.rs.get().map(|rs| &rs[..self.dh_len()])
481 }
482
483 pub fn get_handshake_hash(&self) -> &[u8] {
487 self.symmetricstate.handshake_hash()
488 }
489
490 pub fn is_initiator(&self) -> bool {
492 self.initiator
493 }
494
495 pub fn is_handshake_finished(&self) -> bool {
497 self.pattern_position == self.message_patterns.len()
498 }
499
500 pub fn is_my_turn(&self) -> bool {
502 self.my_turn
503 }
504
505 #[cfg(feature = "risky-raw-split")]
510 pub fn dangerously_get_raw_split(&mut self) -> ([u8; CIPHERKEYLEN], [u8; CIPHERKEYLEN]) {
511 let mut output = ([0u8; MAXHASHLEN], [0u8; MAXHASHLEN]);
512 self.symmetricstate.split_raw(&mut output.0, &mut output.1);
513 (output.0[..CIPHERKEYLEN].try_into().unwrap(), output.1[..CIPHERKEYLEN].try_into().unwrap())
514 }
515
516 pub fn into_transport_mode(self) -> Result<TransportState, Error> {
518 self.try_into()
519 }
520
521 pub fn into_stateless_transport_mode(self) -> Result<StatelessTransportState, Error> {
523 self.try_into()
524 }
525}
526
527impl fmt::Debug for HandshakeState {
528 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
529 fmt.debug_struct("HandshakeState").finish()
530 }
531}