1use alloc::boxed::Box;
2use alloc::vec::Vec;
3use core::mem;
4
5use crate::crypto::hash;
6use crate::msgs::codec::Codec;
7use crate::msgs::enums::HashAlgorithm;
8use crate::msgs::handshake::HandshakeMessagePayload;
9use crate::msgs::message::{Message, MessagePayload};
10
/// Accumulates the encoded handshake messages seen so far, until a hash
/// provider is supplied to `start_hash` and a running [`HandshakeHash`] takes over.
#[derive(Clone)]
pub(crate) struct HandshakeHashBuffer {
    /// Whether the raw bytes should still be retained after `start_hash`.
    client_auth_enabled: bool,
    /// Concatenation of every byte appended so far.
    buffer: Vec<u8>,
}
21
22impl HandshakeHashBuffer {
23 pub(crate) fn new() -> Self {
24 Self {
25 buffer: Vec::new(),
26 client_auth_enabled: false,
27 }
28 }
29
30 pub(crate) fn set_client_auth_enabled(&mut self) {
33 self.client_auth_enabled = true;
34 }
35
36 pub(crate) fn add_message(&mut self, m: &Message<'_>) {
38 if let MessagePayload::Handshake { encoded, .. } = &m.payload {
39 self.buffer
40 .extend_from_slice(encoded.bytes());
41 }
42 }
43
44 #[cfg(all(test, any(feature = "ring", feature = "aws_lc_rs")))]
46 fn update_raw(&mut self, buf: &[u8]) {
47 self.buffer.extend_from_slice(buf);
48 }
49
50 pub(crate) fn hash_given(
52 &self,
53 provider: &'static dyn hash::Hash,
54 extra: &[u8],
55 ) -> hash::Output {
56 let mut ctx = provider.start();
57 ctx.update(&self.buffer);
58 ctx.update(extra);
59 ctx.finish()
60 }
61
62 pub(crate) fn start_hash(self, provider: &'static dyn hash::Hash) -> HandshakeHash {
64 let mut ctx = provider.start();
65 ctx.update(&self.buffer);
66 HandshakeHash {
67 provider,
68 ctx,
69 client_auth: match self.client_auth_enabled {
70 true => Some(self.buffer),
71 false => None,
72 },
73 }
74 }
75}
76
77pub(crate) struct HandshakeHash {
85 provider: &'static dyn hash::Hash,
86 ctx: Box<dyn hash::Context>,
87
88 client_auth: Option<Vec<u8>>,
90}
91
92impl HandshakeHash {
93 pub(crate) fn abandon_client_auth(&mut self) {
96 self.client_auth = None;
97 }
98
99 pub(crate) fn add_message(&mut self, m: &Message<'_>) -> &mut Self {
101 if let MessagePayload::Handshake { encoded, .. } = &m.payload {
102 self.update_raw(encoded.bytes());
103 }
104 self
105 }
106
107 fn update_raw(&mut self, buf: &[u8]) -> &mut Self {
109 self.ctx.update(buf);
110
111 if let Some(buffer) = &mut self.client_auth {
112 buffer.extend_from_slice(buf);
113 }
114
115 self
116 }
117
118 pub(crate) fn hash_given(&self, extra: &[u8]) -> hash::Output {
121 let mut ctx = self.ctx.fork();
122 ctx.update(extra);
123 ctx.finish()
124 }
125
126 pub(crate) fn into_hrr_buffer(self) -> HandshakeHashBuffer {
127 let old_hash = self.ctx.finish();
128 let old_handshake_hash_msg =
129 HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref());
130
131 HandshakeHashBuffer {
132 client_auth_enabled: self.client_auth.is_some(),
133 buffer: old_handshake_hash_msg.get_encoding(),
134 }
135 }
136
137 pub(crate) fn rollup_for_hrr(&mut self) {
141 let ctx = &mut self.ctx;
142
143 let old_ctx = mem::replace(ctx, self.provider.start());
144 let old_hash = old_ctx.finish();
145 let old_handshake_hash_msg =
146 HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref());
147
148 self.update_raw(&old_handshake_hash_msg.get_encoding());
149 }
150
151 pub(crate) fn current_hash(&self) -> hash::Output {
153 self.ctx.fork_finish()
154 }
155
156 #[cfg(feature = "tls12")]
160 pub(crate) fn take_handshake_buf(&mut self) -> Option<Vec<u8>> {
161 self.client_auth.take()
162 }
163
164 pub(crate) fn algorithm(&self) -> HashAlgorithm {
166 self.provider.algorithm()
167 }
168}
169
170impl Clone for HandshakeHash {
171 fn clone(&self) -> Self {
172 Self {
173 provider: self.provider,
174 ctx: self.ctx.fork(),
175 client_auth: self.client_auth.clone(),
176 }
177 }
178}
179
180test_for_each_provider! {
181 use super::HandshakeHashBuffer;
182 use provider::hash::SHA256;
183
184 #[test]
185 fn hashes_correctly() {
186 let mut hhb = HandshakeHashBuffer::new();
187 hhb.update_raw(b"hello");
188 assert_eq!(hhb.buffer.len(), 5);
189 let mut hh = hhb.start_hash(&SHA256);
190 assert!(hh.client_auth.is_none());
191 hh.update_raw(b"world");
192 let h = hh.current_hash();
193 let h = h.as_ref();
194 assert_eq!(h[0], 0x93);
195 assert_eq!(h[1], 0x6a);
196 assert_eq!(h[2], 0x18);
197 assert_eq!(h[3], 0x5c);
198 }
199
200 #[cfg(feature = "tls12")]
201 #[test]
202 fn buffers_correctly() {
203 let mut hhb = HandshakeHashBuffer::new();
204 hhb.set_client_auth_enabled();
205 hhb.update_raw(b"hello");
206 assert_eq!(hhb.buffer.len(), 5);
207 let mut hh = hhb.start_hash(&SHA256);
208 assert_eq!(
209 hh.client_auth
210 .as_ref()
211 .map(|buf| buf.len()),
212 Some(5)
213 );
214 hh.update_raw(b"world");
215 assert_eq!(
216 hh.client_auth
217 .as_ref()
218 .map(|buf| buf.len()),
219 Some(10)
220 );
221 let h = hh.current_hash();
222 let h = h.as_ref();
223 assert_eq!(h[0], 0x93);
224 assert_eq!(h[1], 0x6a);
225 assert_eq!(h[2], 0x18);
226 assert_eq!(h[3], 0x5c);
227 let buf = hh.take_handshake_buf();
228 assert_eq!(Some(b"helloworld".to_vec()), buf);
229 }
230
231 #[test]
232 fn abandon() {
233 let mut hhb = HandshakeHashBuffer::new();
234 hhb.set_client_auth_enabled();
235 hhb.update_raw(b"hello");
236 assert_eq!(hhb.buffer.len(), 5);
237 let mut hh = hhb.start_hash(&SHA256);
238 assert_eq!(
239 hh.client_auth
240 .as_ref()
241 .map(|buf| buf.len()),
242 Some(5)
243 );
244 hh.abandon_client_auth();
245 assert_eq!(hh.client_auth, None);
246 hh.update_raw(b"world");
247 assert_eq!(hh.client_auth, None);
248 let h = hh.current_hash();
249 let h = h.as_ref();
250 assert_eq!(h[0], 0x93);
251 assert_eq!(h[1], 0x6a);
252 assert_eq!(h[2], 0x18);
253 assert_eq!(h[3], 0x5c);
254 }
255
256 #[test]
257 fn clones_correctly() {
258 let mut hhb = HandshakeHashBuffer::new();
259 hhb.set_client_auth_enabled();
260 hhb.update_raw(b"hello");
261 assert_eq!(hhb.buffer.len(), 5);
262
263 let mut hhb_prime = hhb.clone();
265 assert_eq!(hhb_prime.buffer, hhb.buffer);
266 assert!(hhb_prime.client_auth_enabled);
267
268 hhb_prime.update_raw(b"world");
270 assert_eq!(hhb_prime.buffer.len(), 10);
271 assert_ne!(hhb.buffer, hhb_prime.buffer);
272
273 let hh = hhb.start_hash(&SHA256);
274 let hh_hash = hh.current_hash();
275 let hh_hash = hh_hash.as_ref();
276
277 let mut hh_prime = hh.clone();
279 let hh_prime_hash = hh_prime.current_hash();
280 let hh_prime_hash = hh_prime_hash.as_ref();
281 assert_eq!(hh_hash, hh_prime_hash);
282
283 hh_prime.update_raw(b"goodbye");
285 assert_eq!(hh.current_hash().as_ref(), hh_hash);
286 assert_ne!(hh_prime.current_hash().as_ref(), hh_hash);
287 }
288}