1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
use crate::{
    cipherstate::CipherState,
    constants::{CIPHERKEYLEN, MAXHASHLEN},
    error::Error,
    types::Hash,
};

/// The plain-data half of `SymmetricState`: handshake hash, chaining key and key flag.
///
/// It is `Copy` so the whole thing can be snapshotted and restored by value
/// (see `SymmetricState::checkpoint` / `SymmetricState::restore`).
#[derive(Clone, Copy)]
pub(crate) struct SymmetricStateData {
    /// Handshake hash `h`; only the first `hash_len` bytes of the active hasher are used.
    h: [u8; MAXHASHLEN],
    /// Chaining key `ck`; only the first `hash_len` bytes of the active hasher are used.
    ck: [u8; MAXHASHLEN],
    /// Whether a cipher key has been mixed in; decides whether payloads are
    /// encrypted/decrypted or passed through in the clear.
    has_key: bool,
}

impl Default for SymmetricStateData {
    /// An all-zero handshake hash and chaining key with no key mixed in.
    fn default() -> Self {
        let zeroed = [0u8; MAXHASHLEN];
        Self { h: zeroed, ck: zeroed, has_key: false }
    }
}

/// Mirrors the Noise Protocol Framework's `SymmetricState` object: a cipher state,
/// the protocol's hash function, and the copyable `h`/`ck`/key-flag data.
pub(crate) struct SymmetricState {
    /// Cipher keyed from the chaining key and used for handshake payloads.
    cipherstate: CipherState,
    /// Hash function (also providing HKDF) selected for this protocol.
    hasher: Box<dyn Hash>,
    /// Snapshot-able data saved by `checkpoint` and rolled back by `restore`.
    inner: SymmetricStateData,
}

impl SymmetricState {
    /// Creates a symmetric state around `cipherstate` and `hasher`.
    ///
    /// `h` and `ck` start zero-filled and no key is marked as mixed in.
    pub fn new(cipherstate: CipherState, hasher: Box<dyn Hash>) -> SymmetricState {
        SymmetricState { cipherstate, hasher, inner: SymmetricStateData::default() }
    }

    /// Sets `h` and `ck` from the protocol name (Noise `InitializeSymmetric`).
    ///
    /// A name of at most `hash_len` bytes is copied into `h` and zero-padded; a longer
    /// name is hashed into `h`. `ck` is then set to `h` and the key flag is cleared.
    pub fn initialize(&mut self, handshake_name: &str) {
        // Start from an all-zero `h` so a short name is zero-padded even when this state
        // was initialized or mixed into before; `copy_slices!` is expected to overwrite
        // only the leading `handshake_name.len()` bytes.
        self.inner.h = [0u8; MAXHASHLEN];
        if handshake_name.len() <= self.hasher.hash_len() {
            copy_slices!(handshake_name.as_bytes(), self.inner.h);
        } else {
            self.hasher.reset();
            self.hasher.input(handshake_name.as_bytes());
            self.hasher.result(&mut self.inner.h);
        }
        copy_slices!(self.inner.h, &mut self.inner.ck);
        self.inner.has_key = false;
    }

    /// Mixes `data` into the chaining key and derives a fresh cipher key (Noise `MixKey`).
    ///
    /// Runs the hasher's HKDF over `ck` and `data` with two outputs: the first replaces
    /// `ck`, and the first `CIPHERKEYLEN` bytes of the second key the cipher state.
    pub fn mix_key(&mut self, data: &[u8]) {
        let hash_len = self.hasher.hash_len();
        let mut hkdf_output = ([0u8; MAXHASHLEN], [0u8; MAXHASHLEN]);
        self.hasher.hkdf(
            &self.inner.ck[..hash_len],
            data,
            2,
            &mut hkdf_output.0,
            &mut hkdf_output.1,
            &mut [],
        );
        copy_slices!(hkdf_output.0, &mut self.inner.ck);
        // HKDF outputs are hash-length; the cipher only takes CIPHERKEYLEN bytes.
        // The second argument is presumably the starting nonce.
        self.cipherstate.set(&hkdf_output.1[..CIPHERKEYLEN], 0);
        self.inner.has_key = true;
    }

    /// Replaces the handshake hash with `HASH(h || data)` (Noise `MixHash`).
    pub fn mix_hash(&mut self, data: &[u8]) {
        let hash_len = self.hasher.hash_len();
        self.hasher.reset();
        self.hasher.input(&self.inner.h[..hash_len]);
        self.hasher.input(data);
        self.hasher.result(&mut self.inner.h);
    }

    /// Mixes `data` into both the chaining key and the handshake hash, and derives a
    /// fresh cipher key (Noise `MixKeyAndHash`, used for pre-shared keys).
    ///
    /// HKDF over `ck` and `data` yields three outputs: the first replaces `ck`, the
    /// second is mixed into `h`, and the first `CIPHERKEYLEN` bytes of the third key the
    /// cipher state.
    pub fn mix_key_and_hash(&mut self, data: &[u8]) {
        let hash_len = self.hasher.hash_len();
        let mut hkdf_output = ([0u8; MAXHASHLEN], [0u8; MAXHASHLEN], [0u8; MAXHASHLEN]);
        self.hasher.hkdf(
            &self.inner.ck[..hash_len],
            data,
            3,
            &mut hkdf_output.0,
            &mut hkdf_output.1,
            &mut hkdf_output.2,
        );
        copy_slices!(hkdf_output.0, &mut self.inner.ck);
        self.mix_hash(&hkdf_output.1[..hash_len]);
        self.cipherstate.set(&hkdf_output.2[..CIPHERKEYLEN], 0);
        // Noise's MixKeyAndHash calls InitializeKey just like MixKey does, so the
        // cipher state now holds a key and the flag must reflect that.
        self.inner.has_key = true;
    }

    /// Returns whether a cipher key has been mixed in since the last `initialize`.
    pub fn has_key(&self) -> bool {
        self.inner.has_key
    }

    /// Encrypts `plaintext` into `out` and mixes the written bytes into the handshake
    /// hash (Noise `EncryptAndHash`).
    ///
    /// With a key, `plaintext` is encrypted with the current `h` as associated data;
    /// without one, it is copied into `out` unchanged. Returns the number of bytes
    /// written to `out`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the cipher state's `encrypt_ad`.
    ///
    /// NOTE(review): unlike `decrypt_and_mix_hash`, the keyless path does not check
    /// `out.len()`; if `copy_slices!` slices `out[..plaintext.len()]` it will panic on a
    /// short buffer — confirm callers always size `out` adequately.
    pub fn encrypt_and_mix_hash(
        &mut self,
        plaintext: &[u8],
        out: &mut [u8],
    ) -> Result<usize, Error> {
        let hash_len = self.hasher.hash_len();
        let output_len = if self.inner.has_key {
            self.cipherstate.encrypt_ad(&self.inner.h[..hash_len], plaintext, out)?
        } else {
            copy_slices!(plaintext, out);
            plaintext.len()
        };
        self.mix_hash(&out[..output_len]);
        Ok(output_len)
    }

    /// Decrypts `data` into `out` and mixes `data` (the received bytes, not the
    /// plaintext) into the handshake hash (Noise `DecryptAndHash`).
    ///
    /// With a key, `data` is decrypted with the current `h` as associated data; without
    /// one, it is copied into `out` unchanged. Returns the number of payload bytes
    /// written to `out`.
    ///
    /// # Errors
    ///
    /// Returns `Error::Decrypt` when no key is set and `out` is shorter than `data`, and
    /// propagates any error from the cipher state's `decrypt_ad`. On error the handshake
    /// hash is left unchanged.
    pub fn decrypt_and_mix_hash(&mut self, data: &[u8], out: &mut [u8]) -> Result<usize, Error> {
        let hash_len = self.hasher.hash_len();
        let payload_len = if self.inner.has_key {
            self.cipherstate.decrypt_ad(&self.inner.h[..hash_len], data, out)?
        } else {
            if out.len() < data.len() {
                return Err(Error::Decrypt);
            }
            copy_slices!(data, out);
            data.len()
        };
        self.mix_hash(data);
        Ok(payload_len)
    }

    /// Derives the two transport cipher states from the chaining key (Noise `Split`),
    /// keying `child1` and `child2` with the first `CIPHERKEYLEN` bytes of each output.
    pub fn split(&mut self, child1: &mut CipherState, child2: &mut CipherState) {
        let mut hkdf_output = ([0u8; MAXHASHLEN], [0u8; MAXHASHLEN]);
        self.split_raw(&mut hkdf_output.0, &mut hkdf_output.1);
        child1.set(&hkdf_output.0[..CIPHERKEYLEN], 0);
        child2.set(&hkdf_output.1[..CIPHERKEYLEN], 0);
    }

    /// Writes the two raw HKDF outputs of `ck` with empty input into `out1` and `out2`.
    ///
    /// NOTE(review): the buffers presumably need to be at least `hash_len` bytes long —
    /// confirm against the `Hash::hkdf` implementation.
    pub fn split_raw(&mut self, out1: &mut [u8], out2: &mut [u8]) {
        let hash_len = self.hasher.hash_len();
        self.hasher.hkdf(&self.inner.ck[..hash_len], &[0u8; 0], 2, out1, out2, &mut []);
    }

    /// Returns a by-value snapshot of `h`, `ck` and the key flag.
    ///
    /// The cipher state itself is not included in the snapshot.
    pub(crate) fn checkpoint(&self) -> SymmetricStateData {
        self.inner
    }

    /// Rolls `h`, `ck` and the key flag back to a snapshot taken by `checkpoint`.
    pub(crate) fn restore(&mut self, checkpoint: SymmetricStateData) {
        self.inner = checkpoint;
    }

    /// Returns the current handshake hash, truncated to the hasher's output length.
    pub fn handshake_hash(&self) -> &[u8] {
        let hash_len = self.hasher.hash_len();
        &self.inner.h[..hash_len]
    }
}