1use crate::{
2 cipherstate::CipherState,
3 constants::{CIPHERKEYLEN, MAXHASHLEN},
4 error::Error,
5 types::Hash,
6};
7
8#[derive(Copy, Clone)]
9pub(crate) struct SymmetricStateData {
10 h: [u8; MAXHASHLEN],
11 ck: [u8; MAXHASHLEN],
12 has_key: bool,
13}
14
15impl Default for SymmetricStateData {
16 fn default() -> Self {
17 SymmetricStateData {
18 h: [0u8; MAXHASHLEN],
19 ck: [0u8; MAXHASHLEN],
20 has_key: false,
21 }
22 }
23}
24
25pub(crate) struct SymmetricState {
26 cipherstate: CipherState,
27 hasher: Box<dyn Hash>,
28 inner: SymmetricStateData,
29}
30
31impl SymmetricState {
32 pub fn new(cipherstate: CipherState, hasher: Box<dyn Hash>) -> SymmetricState {
33 SymmetricState { cipherstate, hasher, inner: SymmetricStateData::default() }
34 }
35
36 pub fn initialize(&mut self, handshake_name: &str) {
37 if handshake_name.len() <= self.hasher.hash_len() {
38 copy_slices!(handshake_name.as_bytes(), self.inner.h);
39 } else {
40 self.hasher.reset();
41 self.hasher.input(handshake_name.as_bytes());
42 self.hasher.result(&mut self.inner.h);
43 }
44 copy_slices!(self.inner.h, &mut self.inner.ck);
45 self.inner.has_key = false;
46 }
47
48 pub fn mix_key(&mut self, data: &[u8]) {
49 let hash_len = self.hasher.hash_len();
50 let mut hkdf_output = ([0u8; MAXHASHLEN], [0u8; MAXHASHLEN]);
51 self.hasher.hkdf(
52 &self.inner.ck[..hash_len],
53 data,
54 2,
55 &mut hkdf_output.0,
56 &mut hkdf_output.1,
57 &mut [],
58 );
59 copy_slices!(hkdf_output.0, &mut self.inner.ck);
60 self.cipherstate.set(&hkdf_output.1[..CIPHERKEYLEN], 0);
61 self.inner.has_key = true;
62 }
63
64 pub fn mix_hash(&mut self, data: &[u8]) {
65 let hash_len = self.hasher.hash_len();
66 self.hasher.reset();
67 self.hasher.input(&self.inner.h[..hash_len]);
68 self.hasher.input(data);
69 self.hasher.result(&mut self.inner.h);
70 }
71
72 pub fn mix_key_and_hash(&mut self, data: &[u8]) {
73 let hash_len = self.hasher.hash_len();
74 let mut hkdf_output = ([0u8; MAXHASHLEN], [0u8; MAXHASHLEN], [0u8; MAXHASHLEN]);
75 self.hasher.hkdf(
76 &self.inner.ck[..hash_len],
77 data,
78 3,
79 &mut hkdf_output.0,
80 &mut hkdf_output.1,
81 &mut hkdf_output.2,
82 );
83 copy_slices!(hkdf_output.0, &mut self.inner.ck);
84 self.mix_hash(&hkdf_output.1[..hash_len]);
85 self.cipherstate.set(&hkdf_output.2[..CIPHERKEYLEN], 0);
86 }
87
88 pub fn has_key(&self) -> bool {
89 self.inner.has_key
90 }
91
92 pub fn encrypt_and_mix_hash(
94 &mut self,
95 plaintext: &[u8],
96 out: &mut [u8],
97 ) -> Result<usize, Error> {
98 let hash_len = self.hasher.hash_len();
99 let output_len = if self.inner.has_key {
100 self.cipherstate.encrypt_ad(&self.inner.h[..hash_len], plaintext, out)?
101 } else {
102 copy_slices!(plaintext, out);
103 plaintext.len()
104 };
105 self.mix_hash(&out[..output_len]);
106 Ok(output_len)
107 }
108
109 pub fn decrypt_and_mix_hash(&mut self, data: &[u8], out: &mut [u8]) -> Result<usize, Error> {
110 let hash_len = self.hasher.hash_len();
111 let payload_len = if self.inner.has_key {
112 self.cipherstate.decrypt_ad(&self.inner.h[..hash_len], data, out)?
113 } else {
114 if out.len() < data.len() {
115 return Err(Error::Decrypt);
116 }
117 copy_slices!(data, out);
118 data.len()
119 };
120 self.mix_hash(data);
121 Ok(payload_len)
122 }
123
124 pub fn split(&mut self, child1: &mut CipherState, child2: &mut CipherState) {
125 let mut hkdf_output = ([0u8; MAXHASHLEN], [0u8; MAXHASHLEN]);
126 self.split_raw(&mut hkdf_output.0, &mut hkdf_output.1);
127 child1.set(&hkdf_output.0[..CIPHERKEYLEN], 0);
128 child2.set(&hkdf_output.1[..CIPHERKEYLEN], 0);
129 }
130
131 pub fn split_raw(&mut self, out1: &mut [u8], out2: &mut [u8]) {
132 let hash_len = self.hasher.hash_len();
133 self.hasher.hkdf(&self.inner.ck[..hash_len], &[0u8; 0], 2, out1, out2, &mut []);
134 }
135
136 pub(crate) fn checkpoint(&mut self) -> SymmetricStateData {
137 self.inner
138 }
139
140 pub(crate) fn restore(&mut self, checkpoint: SymmetricStateData) {
141 self.inner = checkpoint;
142 }
143
144 pub fn handshake_hash(&self) -> &[u8] {
145 let hash_len = self.hasher.hash_len();
146 &self.inner.h[..hash_len]
147 }
148}