snow/handshakestate.rs

1#[cfg(feature = "risky-raw-split")]
2use crate::constants::{CIPHERKEYLEN, MAXHASHLEN};
3#[cfg(feature = "hfs")]
4use crate::constants::{MAXKEMCTLEN, MAXKEMPUBLEN, MAXKEMSSLEN};
5#[cfg(feature = "hfs")]
6use crate::types::Kem;
7use crate::{
8    cipherstate::{CipherState, CipherStates},
9    constants::{MAXDHLEN, MAXMSGLEN, PSKLEN, TAGLEN},
10    error::{Error, InitStage, StateProblem},
11    params::{DhToken, HandshakeTokens, MessagePatterns, NoiseParams, Token},
12    stateless_transportstate::StatelessTransportState,
13    symmetricstate::SymmetricState,
14    transportstate::TransportState,
15    types::{Dh, Hash, Random},
16    utils::Toggle,
17};
18use std::{
19    convert::{TryFrom, TryInto},
20    fmt,
21};
22
23/// A state machine encompassing the handshake phase of a Noise session.
24///
25/// **Note:** you are probably looking for [`Builder`](struct.Builder.html) to
26/// get started.
27///
28/// See: https://noiseprotocol.org/noise.html#the-handshakestate-object
29pub struct HandshakeState {
30    pub(crate) rng:              Box<dyn Random>,
31    pub(crate) symmetricstate:   SymmetricState,
32    pub(crate) cipherstates:     CipherStates,
33    pub(crate) s:                Toggle<Box<dyn Dh>>,
34    pub(crate) e:                Toggle<Box<dyn Dh>>,
35    pub(crate) fixed_ephemeral:  bool,
36    pub(crate) rs:               Toggle<[u8; MAXDHLEN]>,
37    pub(crate) re:               Toggle<[u8; MAXDHLEN]>,
38    pub(crate) initiator:        bool,
39    pub(crate) params:           NoiseParams,
40    pub(crate) psks:             [Option<[u8; PSKLEN]>; 10],
41    #[cfg(feature = "hfs")]
42    pub(crate) kem:              Option<Box<dyn Kem>>,
43    #[cfg(feature = "hfs")]
44    pub(crate) kem_re:           Option<[u8; MAXKEMPUBLEN]>,
45    pub(crate) my_turn:          bool,
46    pub(crate) message_patterns: MessagePatterns,
47    pub(crate) pattern_position: usize,
48}
49
50impl HandshakeState {
51    #[allow(clippy::too_many_arguments)]
52    pub(crate) fn new(
53        rng: Box<dyn Random>,
54        cipherstate: CipherState,
55        hasher: Box<dyn Hash>,
56        s: Toggle<Box<dyn Dh>>,
57        e: Toggle<Box<dyn Dh>>,
58        fixed_ephemeral: bool,
59        rs: Toggle<[u8; MAXDHLEN]>,
60        re: Toggle<[u8; MAXDHLEN]>,
61        initiator: bool,
62        params: NoiseParams,
63        psks: [Option<[u8; PSKLEN]>; 10],
64        prologue: &[u8],
65        cipherstates: CipherStates,
66    ) -> Result<HandshakeState, Error> {
67        if (s.is_on() && e.is_on() && s.pub_len() != e.pub_len())
68            || (s.is_on() && rs.is_on() && s.pub_len() > rs.len())
69            || (s.is_on() && re.is_on() && s.pub_len() > re.len())
70        {
71            return Err(InitStage::ValidateKeyLengths.into());
72        }
73
74        let tokens = HandshakeTokens::try_from(&params.handshake)?;
75
76        let mut symmetricstate = SymmetricState::new(cipherstate, hasher);
77
78        symmetricstate.initialize(&params.name);
79        symmetricstate.mix_hash(prologue);
80
81        let dh_len = s.pub_len();
82        if initiator {
83            for token in tokens.premsg_pattern_i {
84                symmetricstate.mix_hash(
85                    match *token {
86                        Token::S => &s,
87                        Token::E => &e,
88                        _ => unreachable!(),
89                    }
90                    .get()
91                    .ok_or(StateProblem::MissingKeyMaterial)?
92                    .pubkey(),
93                );
94            }
95            for token in tokens.premsg_pattern_r {
96                symmetricstate.mix_hash(
97                    &match *token {
98                        Token::S => &rs,
99                        Token::E => &re,
100                        _ => unreachable!(),
101                    }
102                    .get()
103                    .ok_or(StateProblem::MissingKeyMaterial)?[..dh_len],
104                );
105            }
106        } else {
107            for token in tokens.premsg_pattern_i {
108                symmetricstate.mix_hash(
109                    &match *token {
110                        Token::S => &rs,
111                        Token::E => &re,
112                        _ => unreachable!(),
113                    }
114                    .get()
115                    .ok_or(StateProblem::MissingKeyMaterial)?[..dh_len],
116                );
117            }
118            for token in tokens.premsg_pattern_r {
119                symmetricstate.mix_hash(
120                    match *token {
121                        Token::S => &s,
122                        Token::E => &e,
123                        _ => unreachable!(),
124                    }
125                    .get()
126                    .ok_or(StateProblem::MissingKeyMaterial)?
127                    .pubkey(),
128                );
129            }
130        }
131
132        Ok(HandshakeState {
133            rng,
134            symmetricstate,
135            cipherstates,
136            s,
137            e,
138            fixed_ephemeral,
139            rs,
140            re,
141            initiator,
142            params,
143            psks,
144            #[cfg(feature = "hfs")]
145            kem: None,
146            #[cfg(feature = "hfs")]
147            kem_re: None,
148            my_turn: initiator,
149            message_patterns: tokens.msg_patterns,
150            pattern_position: 0,
151        })
152    }
153
154    pub(crate) fn dh_len(&self) -> usize {
155        self.s.pub_len()
156    }
157
158    #[cfg(feature = "hfs")]
159    pub(crate) fn set_kem(&mut self, kem: Box<dyn Kem>) {
160        self.kem = Some(kem);
161    }
162
163    fn dh(&self, token: &DhToken) -> Result<[u8; MAXDHLEN], Error> {
164        let mut dh_out = [0u8; MAXDHLEN];
165        let (dh, key) = match (token, self.is_initiator()) {
166            (DhToken::Ee, _) => (&self.e, &self.re),
167            (DhToken::Ss, _) => (&self.s, &self.rs),
168            (DhToken::Se, true) | (DhToken::Es, false) => (&self.s, &self.re),
169            (DhToken::Es, true) | (DhToken::Se, false) => (&self.e, &self.rs),
170        };
171        if !(dh.is_on() && key.is_on()) {
172            return Err(StateProblem::MissingKeyMaterial.into());
173        }
174        dh.dh(&**key, &mut dh_out)?;
175        Ok(dh_out)
176    }
177
178    /// This method will return `true` if the *previous* write payload was encrypted.
179    ///
180    /// See [Payload Security Properties](https://noiseprotocol.org/noise.html#payload-security-properties)
181    /// for more information on the specific properties of your chosen handshake pattern.
182    ///
183    /// # Examples
184    ///
185    /// ```rust,ignore
186    /// let mut session = Builder::new("Noise_NN_25519_AESGCM_SHA256".parse()?)
187    ///     .build_initiator()?;
188    ///
189    /// // write message...
190    ///
191    /// assert!(session.was_write_payload_encrypted());
192    /// ```
193    pub fn was_write_payload_encrypted(&self) -> bool {
194        self.symmetricstate.has_key()
195    }
196
197    /// Construct a message from `payload` (and pending handshake tokens if in handshake state),
198    /// and writes it to the `message` buffer.
199    ///
200    /// Returns the size of the written payload.
201    ///
202    /// # Errors
203    ///
204    /// Will result in `Error::Input` if the size of the output exceeds the max message
205    /// length in the Noise Protocol (65535 bytes).
206    pub fn write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error> {
207        let checkpoint = self.symmetricstate.checkpoint();
208        match self._write_message(payload, message) {
209            Ok(res) => {
210                self.pattern_position += 1;
211                self.my_turn = false;
212                Ok(res)
213            },
214            Err(err) => {
215                self.symmetricstate.restore(checkpoint);
216                Err(err)
217            },
218        }
219    }
220
221    fn _write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error> {
222        if !self.my_turn {
223            return Err(StateProblem::NotTurnToWrite.into());
224        } else if self.pattern_position >= self.message_patterns.len() {
225            return Err(StateProblem::HandshakeAlreadyFinished.into());
226        }
227
228        let mut byte_index = 0;
229        for token in self.message_patterns[self.pattern_position].iter() {
230            match token {
231                Token::E => {
232                    if byte_index + self.e.pub_len() > message.len() {
233                        return Err(Error::Input);
234                    }
235
236                    if !self.fixed_ephemeral {
237                        self.e.generate(&mut *self.rng);
238                    }
239                    let pubkey = self.e.pubkey();
240                    message[byte_index..byte_index + pubkey.len()].copy_from_slice(pubkey);
241                    byte_index += pubkey.len();
242                    self.symmetricstate.mix_hash(pubkey);
243                    if self.params.handshake.is_psk() {
244                        self.symmetricstate.mix_key(pubkey);
245                    }
246                    self.e.enable();
247                },
248                Token::S => {
249                    if !self.s.is_on() {
250                        return Err(StateProblem::MissingKeyMaterial.into());
251                    } else if byte_index + self.s.pub_len() > message.len() {
252                        return Err(Error::Input);
253                    }
254
255                    byte_index += self
256                        .symmetricstate
257                        .encrypt_and_mix_hash(self.s.pubkey(), &mut message[byte_index..])?;
258                },
259                Token::Psk(n) => match self.psks[*n as usize] {
260                    Some(psk) => {
261                        self.symmetricstate.mix_key_and_hash(&psk);
262                    },
263                    None => {
264                        return Err(StateProblem::MissingPsk.into());
265                    },
266                },
267                Token::Dh(t) => {
268                    let dh_out = self.dh(t)?;
269                    self.symmetricstate.mix_key(&dh_out[..self.dh_len()]);
270                },
271                #[cfg(feature = "hfs")]
272                Token::E1 => {
273                    let kem = self.kem.as_mut().ok_or(Error::Input)?;
274                    if kem.pub_len() > message.len() {
275                        return Err(Error::Input);
276                    }
277
278                    kem.generate(&mut *self.rng);
279                    byte_index += self
280                        .symmetricstate
281                        .encrypt_and_mix_hash(kem.pubkey(), &mut message[byte_index..])?;
282                },
283                #[cfg(feature = "hfs")]
284                Token::Ekem1 => {
285                    let kem = self.kem.as_mut().unwrap();
286                    let mut kem_output_buf = [0; MAXKEMSSLEN];
287                    let mut ciphertext_buf = [0; MAXKEMCTLEN];
288
289                    if kem.ciphertext_len() > message.len() {
290                        return Err(Error::Input);
291                    }
292
293                    let kem_output = &mut kem_output_buf[..kem.shared_secret_len()];
294                    let ciphertext = &mut ciphertext_buf[..kem.ciphertext_len()];
295                    let pubkey = &self.kem_re.as_ref().unwrap()[..kem.pub_len()];
296                    if kem.encapsulate(pubkey, kem_output, ciphertext).is_err() {
297                        return Err(Error::Kem);
298                    }
299
300                    byte_index += self.symmetricstate.encrypt_and_mix_hash(
301                        &ciphertext[..kem.ciphertext_len()],
302                        &mut message[byte_index..],
303                    )?;
304                    self.symmetricstate.mix_key(&kem_output[..kem.shared_secret_len()]);
305                },
306            }
307        }
308
309        if byte_index + payload.len() + TAGLEN > message.len() {
310            return Err(Error::Input);
311        }
312        byte_index +=
313            self.symmetricstate.encrypt_and_mix_hash(payload, &mut message[byte_index..])?;
314        if byte_index > MAXMSGLEN {
315            return Err(Error::Input);
316        }
317        if self.pattern_position == (self.message_patterns.len() - 1) {
318            self.symmetricstate.split(&mut self.cipherstates.0, &mut self.cipherstates.1);
319        }
320        Ok(byte_index)
321    }
322
323    /// Reads a noise message from `input`
324    ///
325    /// Returns the size of the payload written to `payload`.
326    ///
327    /// # Errors
328    ///
329    /// Will result in `Error::Decrypt` if the contents couldn't be decrypted and/or the
330    /// authentication tag didn't verify.
331    ///
332    /// Will result in `StateProblem::Exhausted` if the max nonce count overflows.
333    pub fn read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, Error> {
334        let checkpoint = self.symmetricstate.checkpoint();
335        match self._read_message(message, payload) {
336            Ok(res) => {
337                self.pattern_position += 1;
338                self.my_turn = true;
339                Ok(res)
340            },
341            Err(err) => {
342                self.symmetricstate.restore(checkpoint);
343                Err(err)
344            },
345        }
346    }
347
348    fn _read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, Error> {
349        if message.len() > MAXMSGLEN {
350            return Err(Error::Input);
351        } else if self.my_turn {
352            return Err(StateProblem::NotTurnToRead.into());
353        } else if self.pattern_position >= self.message_patterns.len() {
354            return Err(StateProblem::HandshakeAlreadyFinished.into());
355        }
356        let last = self.pattern_position == (self.message_patterns.len() - 1);
357
358        let dh_len = self.dh_len();
359        let mut ptr = message;
360        for token in self.message_patterns[self.pattern_position].iter() {
361            match token {
362                Token::E => {
363                    if ptr.len() < dh_len {
364                        return Err(Error::Input);
365                    }
366                    self.re[..dh_len].copy_from_slice(&ptr[..dh_len]);
367                    ptr = &ptr[dh_len..];
368                    self.symmetricstate.mix_hash(&self.re[..dh_len]);
369                    if self.params.handshake.is_psk() {
370                        self.symmetricstate.mix_key(&self.re[..dh_len]);
371                    }
372                    self.re.enable();
373                },
374                Token::S => {
375                    let data = if self.symmetricstate.has_key() {
376                        if ptr.len() < dh_len + TAGLEN {
377                            return Err(Error::Input);
378                        }
379                        let temp = &ptr[..dh_len + TAGLEN];
380                        ptr = &ptr[dh_len + TAGLEN..];
381                        temp
382                    } else {
383                        if ptr.len() < dh_len {
384                            return Err(Error::Input);
385                        }
386                        let temp = &ptr[..dh_len];
387                        ptr = &ptr[dh_len..];
388                        temp
389                    };
390                    self.symmetricstate.decrypt_and_mix_hash(data, &mut self.rs[..dh_len])?;
391                    self.rs.enable();
392                },
393                Token::Psk(n) => match self.psks[*n as usize] {
394                    Some(psk) => {
395                        self.symmetricstate.mix_key_and_hash(&psk);
396                    },
397                    None => {
398                        return Err(StateProblem::MissingPsk.into());
399                    },
400                },
401                Token::Dh(t) => {
402                    let dh_out = self.dh(t)?;
403                    self.symmetricstate.mix_key(&dh_out[..self.dh_len()]);
404                },
405                #[cfg(feature = "hfs")]
406                Token::E1 => {
407                    let kem = self.kem.as_ref().ok_or(Error::Kem)?;
408                    let read_len = if self.symmetricstate.has_key() {
409                        kem.pub_len() + TAGLEN
410                    } else {
411                        kem.pub_len()
412                    };
413                    if ptr.len() < read_len {
414                        return Err(Error::Input);
415                    }
416                    let mut kem_re = [0; MAXKEMPUBLEN];
417                    self.symmetricstate
418                        .decrypt_and_mix_hash(&ptr[..read_len], &mut kem_re[..kem.pub_len()])?;
419                    self.kem_re = Some(kem_re);
420                    ptr = &ptr[read_len..];
421                },
422                #[cfg(feature = "hfs")]
423                Token::Ekem1 => {
424                    let kem = self.kem.as_ref().unwrap();
425                    let read_len = if self.symmetricstate.has_key() {
426                        kem.ciphertext_len() + TAGLEN
427                    } else {
428                        kem.ciphertext_len()
429                    };
430                    if ptr.len() < read_len {
431                        return Err(Error::Input);
432                    }
433                    let mut ciphertext_buf = [0; MAXKEMCTLEN];
434                    let ciphertext = &mut ciphertext_buf[..kem.ciphertext_len()];
435                    self.symmetricstate.decrypt_and_mix_hash(&ptr[..read_len], ciphertext)?;
436                    let mut kem_output_buf = [0; MAXKEMSSLEN];
437                    let kem_output = &mut kem_output_buf[..kem.shared_secret_len()];
438                    kem.decapsulate(ciphertext, kem_output).map_err(|_| Error::Kem)?;
439                    self.symmetricstate.mix_key(&kem_output[..kem.shared_secret_len()]);
440                    ptr = &ptr[read_len..];
441                },
442            }
443        }
444
445        self.symmetricstate.decrypt_and_mix_hash(ptr, payload)?;
446        if last {
447            self.symmetricstate.split(&mut self.cipherstates.0, &mut self.cipherstates.1);
448        }
449        let payload_len =
450            if self.symmetricstate.has_key() { ptr.len() - TAGLEN } else { ptr.len() };
451        Ok(payload_len)
452    }
453
454    /// Set the preshared key at the specified location. It is up to the caller
455    /// to correctly set the location based on the specified handshake - Snow
456    /// won't stop you from placing a PSK in an unused slot.
457    ///
458    /// # Errors
459    ///
460    /// Will result in `Error::Input` if the PSK is not the right length or the location is out of bounds.
461    pub fn set_psk(&mut self, location: usize, key: &[u8]) -> Result<(), Error> {
462        if key.len() != PSKLEN || self.psks.len() <= location {
463            return Err(Error::Input);
464        }
465
466        let mut new_psk = [0u8; PSKLEN];
467        new_psk.copy_from_slice(key);
468        self.psks[location as usize] = Some(new_psk);
469
470        Ok(())
471    }
472
473    /// Get the remote party's static public key, if available.
474    ///
475    /// Note: will return `None` if either the chosen Noise pattern
476    /// doesn't necessitate a remote static key, *or* if the remote
477    /// static key is not yet known (as can be the case in the `XX`
478    /// pattern, for example).
479    pub fn get_remote_static(&self) -> Option<&[u8]> {
480        self.rs.get().map(|rs| &rs[..self.dh_len()])
481    }
482
483    /// Get the handshake hash.
484    ///
485    /// Returns a slice of length `Hasher.hash_len()` (i.e. HASHLEN for the chosen Hash function).
486    pub fn get_handshake_hash(&self) -> &[u8] {
487        self.symmetricstate.handshake_hash()
488    }
489
490    /// Check if this session was started with the "initiator" role.
491    pub fn is_initiator(&self) -> bool {
492        self.initiator
493    }
494
495    /// Check if the handshake is finished and `into_transport_mode()` can now be called.
496    pub fn is_handshake_finished(&self) -> bool {
497        self.pattern_position == self.message_patterns.len()
498    }
499
500    /// Check whether it is our turn to send in the handshake state machine
501    pub fn is_my_turn(&self) -> bool {
502        self.my_turn
503    }
504
505    /// Perform the split calculation and return the resulting keys.
506    ///
507    /// This returns raw key material so it should be used with care. The "risky-raw-split"
508    /// feature has to be enabled to use this function.
509    #[cfg(feature = "risky-raw-split")]
510    pub fn dangerously_get_raw_split(&mut self) -> ([u8; CIPHERKEYLEN], [u8; CIPHERKEYLEN]) {
511        let mut output = ([0u8; MAXHASHLEN], [0u8; MAXHASHLEN]);
512        self.symmetricstate.split_raw(&mut output.0, &mut output.1);
513        (output.0[..CIPHERKEYLEN].try_into().unwrap(), output.1[..CIPHERKEYLEN].try_into().unwrap())
514    }
515
516    /// Convert this `HandshakeState` into a `TransportState` with an internally stored nonce.
517    pub fn into_transport_mode(self) -> Result<TransportState, Error> {
518        self.try_into()
519    }
520
521    /// Convert this `HandshakeState` into a `StatelessTransportState` without an internally stored nonce.
522    pub fn into_stateless_transport_mode(self) -> Result<StatelessTransportState, Error> {
523        self.try_into()
524    }
525}
526
527impl fmt::Debug for HandshakeState {
528    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
529        fmt.debug_struct("HandshakeState").finish()
530    }
531}