// snow/symmetricstate.rs

1use crate::{
2    cipherstate::CipherState,
3    constants::{CIPHERKEYLEN, MAXHASHLEN},
4    error::Error,
5    types::Hash,
6};
7
8#[derive(Copy, Clone)]
9pub(crate) struct SymmetricStateData {
10    h:       [u8; MAXHASHLEN],
11    ck:      [u8; MAXHASHLEN],
12    has_key: bool,
13}
14
15impl Default for SymmetricStateData {
16    fn default() -> Self {
17        SymmetricStateData {
18            h:       [0u8; MAXHASHLEN],
19            ck:      [0u8; MAXHASHLEN],
20            has_key: false,
21        }
22    }
23}
24
25pub(crate) struct SymmetricState {
26    cipherstate: CipherState,
27    hasher:      Box<dyn Hash>,
28    inner:       SymmetricStateData,
29}
30
31impl SymmetricState {
32    pub fn new(cipherstate: CipherState, hasher: Box<dyn Hash>) -> SymmetricState {
33        SymmetricState { cipherstate, hasher, inner: SymmetricStateData::default() }
34    }
35
36    pub fn initialize(&mut self, handshake_name: &str) {
37        if handshake_name.len() <= self.hasher.hash_len() {
38            copy_slices!(handshake_name.as_bytes(), self.inner.h);
39        } else {
40            self.hasher.reset();
41            self.hasher.input(handshake_name.as_bytes());
42            self.hasher.result(&mut self.inner.h);
43        }
44        copy_slices!(self.inner.h, &mut self.inner.ck);
45        self.inner.has_key = false;
46    }
47
48    pub fn mix_key(&mut self, data: &[u8]) {
49        let hash_len = self.hasher.hash_len();
50        let mut hkdf_output = ([0u8; MAXHASHLEN], [0u8; MAXHASHLEN]);
51        self.hasher.hkdf(
52            &self.inner.ck[..hash_len],
53            data,
54            2,
55            &mut hkdf_output.0,
56            &mut hkdf_output.1,
57            &mut [],
58        );
59        copy_slices!(hkdf_output.0, &mut self.inner.ck);
60        self.cipherstate.set(&hkdf_output.1[..CIPHERKEYLEN], 0);
61        self.inner.has_key = true;
62    }
63
64    pub fn mix_hash(&mut self, data: &[u8]) {
65        let hash_len = self.hasher.hash_len();
66        self.hasher.reset();
67        self.hasher.input(&self.inner.h[..hash_len]);
68        self.hasher.input(data);
69        self.hasher.result(&mut self.inner.h);
70    }
71
72    pub fn mix_key_and_hash(&mut self, data: &[u8]) {
73        let hash_len = self.hasher.hash_len();
74        let mut hkdf_output = ([0u8; MAXHASHLEN], [0u8; MAXHASHLEN], [0u8; MAXHASHLEN]);
75        self.hasher.hkdf(
76            &self.inner.ck[..hash_len],
77            data,
78            3,
79            &mut hkdf_output.0,
80            &mut hkdf_output.1,
81            &mut hkdf_output.2,
82        );
83        copy_slices!(hkdf_output.0, &mut self.inner.ck);
84        self.mix_hash(&hkdf_output.1[..hash_len]);
85        self.cipherstate.set(&hkdf_output.2[..CIPHERKEYLEN], 0);
86    }
87
88    pub fn has_key(&self) -> bool {
89        self.inner.has_key
90    }
91
92    /// Encrypt a message and mixes in the hash of the output
93    pub fn encrypt_and_mix_hash(
94        &mut self,
95        plaintext: &[u8],
96        out: &mut [u8],
97    ) -> Result<usize, Error> {
98        let hash_len = self.hasher.hash_len();
99        let output_len = if self.inner.has_key {
100            self.cipherstate.encrypt_ad(&self.inner.h[..hash_len], plaintext, out)?
101        } else {
102            copy_slices!(plaintext, out);
103            plaintext.len()
104        };
105        self.mix_hash(&out[..output_len]);
106        Ok(output_len)
107    }
108
109    pub fn decrypt_and_mix_hash(&mut self, data: &[u8], out: &mut [u8]) -> Result<usize, Error> {
110        let hash_len = self.hasher.hash_len();
111        let payload_len = if self.inner.has_key {
112            self.cipherstate.decrypt_ad(&self.inner.h[..hash_len], data, out)?
113        } else {
114            if out.len() < data.len() {
115                return Err(Error::Decrypt);
116            }
117            copy_slices!(data, out);
118            data.len()
119        };
120        self.mix_hash(data);
121        Ok(payload_len)
122    }
123
124    pub fn split(&mut self, child1: &mut CipherState, child2: &mut CipherState) {
125        let mut hkdf_output = ([0u8; MAXHASHLEN], [0u8; MAXHASHLEN]);
126        self.split_raw(&mut hkdf_output.0, &mut hkdf_output.1);
127        child1.set(&hkdf_output.0[..CIPHERKEYLEN], 0);
128        child2.set(&hkdf_output.1[..CIPHERKEYLEN], 0);
129    }
130
131    pub fn split_raw(&mut self, out1: &mut [u8], out2: &mut [u8]) {
132        let hash_len = self.hasher.hash_len();
133        self.hasher.hkdf(&self.inner.ck[..hash_len], &[0u8; 0], 2, out1, out2, &mut []);
134    }
135
136    pub(crate) fn checkpoint(&mut self) -> SymmetricStateData {
137        self.inner
138    }
139
140    pub(crate) fn restore(&mut self, checkpoint: SymmetricStateData) {
141        self.inner = checkpoint;
142    }
143
144    pub fn handshake_hash(&self) -> &[u8] {
145        let hash_len = self.hasher.hash_len();
146        &self.inner.h[..hash_len]
147    }
148}