1use crate::error::{Error, PatternProblem};
2use std::{convert::TryFrom, str::FromStr};
3
/// Builds a `MessagePatterns` from a list of token-slice expressions, one per handshake
/// message, e.g. `message_vec![&[E], &[E, Dh(Ee)]]`.
///
/// Every message is copied into its own `Vec` with capacity 10, and the outer `Vec` also starts
/// with capacity 10. NOTE(review): presumably this leaves room for the tokens that modifiers
/// insert later (`Psk`, `E1`, `Ekem1`) without reallocating.
macro_rules! message_vec {
    ($($item:expr),*) => ({
        let messages: &[&[Token]] = &[$($item),*];
        let mut patterns: MessagePatterns = Vec::with_capacity(10);
        patterns.extend(messages.iter().map(|tokens| {
            let mut owned = Vec::with_capacity(10);
            owned.extend_from_slice(tokens);
            owned
        }));
        patterns
    });
}
18
/// Generates a fieldless `Copy` enum named `$name` with one variant per listed identifier,
/// together with:
///
/// * a `FromStr` impl that accepts exactly the variant identifiers (case-sensitive) and fails
///   with `PatternProblem::UnsupportedHandshakeType` for anything else,
/// * an `as_str` method returning the identifier of a variant (the inverse of the parse), and
/// * a hidden `SUPPORTED_HANDSHAKE_PATTERNS` constant listing every variant in declaration order.
///
/// `FromStr`, `Error` and `PatternProblem` are referred to by bare name, so they must be in scope
/// where the macro is invoked.
macro_rules! pattern_enum {
    ($name:ident {
        $($variant:ident),* $(,)*
    }) => {
        #[allow(missing_docs)]
        #[derive(Copy, Clone, PartialEq, Debug)]
        pub enum $name {
            $($variant),*,
        }

        impl FromStr for $name {
            type Err = Error;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s {
                    // Each variant parses from its own identifier, spelled exactly.
                    $(stringify!($variant) => Ok($name::$variant),)*
                    _ => Err(PatternProblem::UnsupportedHandshakeType.into()),
                }
            }
        }

        impl $name {
            /// Returns the variant's identifier exactly as written in the pattern list. This is
            /// the same string that `from_str` accepts for the variant.
            pub fn as_str(self) -> &'static str {
                match self {
                    $($name::$variant => stringify!($variant),)*
                }
            }
        }

        #[doc(hidden)]
        pub const SUPPORTED_HANDSHAKE_PATTERNS: &[$name] = &[$($name::$variant),*];
    }
}
71
/// The pair of keys combined by a DH token in a message pattern.
///
/// Variant names follow the two-letter Noise token names `ee`, `es`, `se` and `ss`. For example,
/// the `NN` pattern's second message is `[E, Dh(Ee)]`.
/// NOTE(review): per the Noise spec, the first letter is presumably the initiator's key and the
/// second the responder's (`e` = ephemeral, `s` = static). Confirm against the code that performs
/// the DH.
#[allow(missing_docs)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub(crate) enum DhToken {
    Ee,
    Es,
    Se,
    Ss,
}
83
/// One token of a handshake message pattern or pre-message pattern.
///
/// * `E` / `S`: the `e` / `s` tokens. For example, the `NN` pattern's first message is `[E]`.
/// * `Dh(_)`: a DH token; the payload (`DhToken`) names which pair of keys is combined.
/// * `Psk(n)`: inserted by the `pskN` modifier (see `apply_psk_modifier`).
/// * `E1` / `Ekem1`: the `e1` / `ekem1` tokens inserted by the `hfs` modifier
///   (see `apply_hfs_modifier`). These variants exist only with the `hfs` feature.
#[allow(missing_docs)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub(crate) enum Token {
    E,
    S,
    Dh(DhToken),
    Psk(u8),
    #[cfg(feature = "hfs")]
    E1,
    #[cfg(feature = "hfs")]
    Ekem1,
}
99
#[cfg(feature = "hfs")]
impl Token {
    /// Returns `true` for any DH token (`Dh(_)`), whichever key pair it names.
    fn is_dh(&self) -> bool {
        matches!(self, Token::Dh(_))
    }
}
109
// The handshake patterns this module knows how to expand. `pattern_enum!` turns this list into
// the public `HandshakePattern` enum, which is parsed from and rendered as these exact
// identifiers. It also produces the hidden `SUPPORTED_HANDSHAKE_PATTERNS` constant, which holds
// the patterns in this order.
pattern_enum! {
    HandshakePattern {
        // One-way patterns (see `HandshakePattern::is_oneway`).
        N, X, K,

        // Fundamental interactive patterns.
        NN, NK, NX, XN, XK, XX, KN, KK, KX, IN, IK, IX,

        // Deferred variants. A `1` after a letter defers a DH involving that party's static key
        // to a later message; compare e.g. the `XK` and `XK1` tables in
        // `HandshakeTokens::try_from`.
        NK1, NX1, X1N, X1K, XK1, X1K1, X1X, XX1, X1X1, K1N, K1K, KK1, K1K1, K1X,
        KX1, K1X1, I1N, I1K, IK1, I1K1, I1X, IX1, I1X1
    }
}
124
125impl HandshakePattern {
126 pub fn is_oneway(self) -> bool {
130 matches!(self, N | X | K)
131 }
132
133 pub fn needs_local_static_key(self, initiator: bool) -> bool {
135 if initiator {
136 !matches!(self, N | NN | NK | NX | NK1 | NX1)
137 } else {
138 !matches!(self, NN | XN | KN | IN | X1N | K1N | I1N)
139 }
140 }
141
142 #[rustfmt::skip]
144 pub fn need_known_remote_pubkey(self, initiator: bool) -> bool {
145 if initiator {
146 matches!(
147 self,
148 N | K | X | NK | XK | KK | IK | NK1 | X1K | XK1 | X1K1 | K1K | KK1 | K1K1 | I1K | IK1 | I1K1
149 )
150 } else {
151 matches!(
152 self,
153 K | KN | KK | KX | K1N | K1K | KK1 | K1K1 | K1X | KX1 | K1X1
154 )
155 }
156 }
157}
158
/// A modifier appended to a handshake pattern name, e.g. the `psk0` in `XXpsk0`.
#[derive(Copy, Clone, PartialEq, Debug)]
pub enum HandshakeModifier {
    /// `pskN`: adds a `Psk(N)` token. With `N == 0` the token goes at the start of the first
    /// message; otherwise it is appended to the end of message `N` (1-based).
    Psk(u8),

    /// `fallback`. NOTE(review): this modifier is parsed, but `HandshakeTokens::try_from`
    /// currently rejects it with `UnsupportedModifier`.
    Fallback,

    /// `hfs` (only with the `hfs` feature): adds the `E1` and `Ekem1` tokens.
    #[cfg(feature = "hfs")]
    Hfs,
}
172
173impl FromStr for HandshakeModifier {
174 type Err = Error;
175
176 fn from_str(s: &str) -> Result<Self, Self::Err> {
177 match s {
178 s if s.starts_with("psk") => {
179 Ok(HandshakeModifier::Psk(s[3..].parse().map_err(|_| PatternProblem::InvalidPsk)?))
180 },
181 "fallback" => Ok(HandshakeModifier::Fallback),
182 #[cfg(feature = "hfs")]
183 "hfs" => Ok(HandshakeModifier::Hfs),
184 _ => Err(PatternProblem::UnsupportedModifier.into()),
185 }
186 }
187}
188
/// The modifiers that follow a pattern name, parsed from a `+`-separated string such as
/// `psk0+psk2`. When produced by `FromStr`, the list contains no duplicates.
#[derive(Clone, PartialEq, Debug)]
pub struct HandshakeModifierList {
    /// The modifiers, in the order they were written. `HandshakeTokens::try_from` applies them
    /// in this order.
    pub list: Vec<HandshakeModifier>,
}
195
196impl FromStr for HandshakeModifierList {
197 type Err = Error;
198
199 fn from_str(s: &str) -> Result<Self, Self::Err> {
200 if s.is_empty() {
201 Ok(HandshakeModifierList { list: vec![] })
202 } else {
203 let modifier_names = s.split('+');
204 let mut modifiers = vec![];
205 for modifier_name in modifier_names {
206 let modifier: HandshakeModifier = modifier_name.parse()?;
207 if modifiers.contains(&modifier) {
208 return Err(Error::Pattern(PatternProblem::UnsupportedModifier));
209 } else {
210 modifiers.push(modifier);
211 }
212 }
213 Ok(HandshakeModifierList { list: modifiers })
214 }
215 }
216}
217
/// A parsed handshake name: a base pattern plus its modifiers. For example, `XXpsk3` is the
/// pattern `XX` with the modifier list `[psk3]`.
#[derive(Clone, PartialEq, Debug)]
pub struct HandshakeChoice {
    /// The base handshake pattern.
    pub pattern: HandshakePattern,

    /// The modifiers applied on top of `pattern`.
    pub modifiers: HandshakeModifierList,
}
228
229impl HandshakeChoice {
230 pub fn is_psk(&self) -> bool {
232 for modifier in &self.modifiers.list {
233 if let HandshakeModifier::Psk(_) = *modifier {
234 return true;
235 }
236 }
237 false
238 }
239
240 pub fn is_fallback(&self) -> bool {
242 self.modifiers.list.contains(&HandshakeModifier::Fallback)
243 }
244
245 #[cfg(feature = "hfs")]
247 pub fn is_hfs(&self) -> bool {
248 self.modifiers.list.contains(&HandshakeModifier::Hfs)
249 }
250
251 fn parse_pattern_and_modifier(s: &str) -> Result<(HandshakePattern, &str), Error> {
253 for i in (1..=4).rev() {
254 if s.len() > i - 1 && s.is_char_boundary(i) {
255 if let Ok(p) = s[..i].parse() {
256 return Ok((p, &s[i..]));
257 }
258 }
259 }
260
261 Err(PatternProblem::UnsupportedHandshakeType.into())
262 }
263}
264
265impl FromStr for HandshakeChoice {
266 type Err = Error;
267
268 fn from_str(s: &str) -> Result<Self, Self::Err> {
269 let (pattern, remainder) = Self::parse_pattern_and_modifier(s)?;
270 let modifiers = remainder.parse()?;
271
272 Ok(HandshakeChoice { pattern, modifiers })
273 }
274}
275
/// The tokens of one party's pre-message pattern: a fixed, static list (e.g. `[S]` or empty).
type PremessagePatterns = &'static [Token];
/// The tokens of every handshake message, one inner `Vec` per message, in order. The lists are
/// owned so that modifiers can insert tokens into them.
pub(crate) type MessagePatterns = Vec<Vec<Token>>;
278
/// The fully expanded token lists for a `HandshakeChoice`, with all modifiers already applied.
#[derive(Debug)]
pub(crate) struct HandshakeTokens {
    /// The initiator's pre-message tokens (`[S]` for the `K*` patterns, otherwise empty).
    pub premsg_pattern_i: PremessagePatterns,
    /// The responder's pre-message tokens (`[S]` for `N`, `K`, `X` and the `*K` patterns,
    /// otherwise empty).
    pub premsg_pattern_r: PremessagePatterns,
    /// The tokens of each handshake message, in order.
    pub msg_patterns: MessagePatterns,
}
288
289use self::{DhToken::*, HandshakePattern::*, Token::*};
290
/// The working tuple used while building `HandshakeTokens`: (initiator pre-message, responder
/// pre-message, messages).
type Patterns = (PremessagePatterns, PremessagePatterns, MessagePatterns);
292
impl<'a> TryFrom<&'a HandshakeChoice> for HandshakeTokens {
    type Error = Error;

    /// Expands `handshake` into concrete token lists.
    ///
    /// The base pattern's pre-message and message tokens are looked up in the table below. Every
    /// modifier is then applied in the order it appears in `handshake.modifiers.list`.
    ///
    /// # Errors
    ///
    /// * `PatternProblem::InvalidPsk` if a `pskN` modifier targets a message the pattern does not
    ///   have.
    /// * `PatternProblem::UnsupportedModifier` for the `fallback` modifier, which is parsed but
    ///   not expanded here. With the `hfs` feature, this error is also returned for `hfs` on a
    ///   one-way pattern.
    // NOTE(review): the lint allowance is presumably for the long pattern-table `match` below,
    // which is flat data rather than real branching logic.
    #[allow(clippy::cognitive_complexity)]
    fn try_from(handshake: &'a HandshakeChoice) -> Result<Self, Self::Error> {
        // `hfs` needs an `ee` token to anchor on, and one-way patterns lack one, so reject that
        // combination before building anything.
        #[cfg(feature = "hfs")]
        check_hfs_and_oneway_conflict(handshake)?;

        // Base tables: (initiator pre-message, responder pre-message, messages in order).
        // `static_slice!` is defined outside this file. Here it must yield `&'static [Token]`
        // (`PremessagePatterns`) to fit the first two tuple slots.
        #[rustfmt::skip]
        let mut patterns: Patterns = match handshake.pattern {
            // One-way patterns.
            N => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es)]]
            ),
            K => (
                static_slice![Token: S],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es), Dh(Ss)]]
            ),
            X => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es), S, Dh(Ss)]]
            ),
            // Fundamental interactive patterns.
            NN => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee)]]
            ),
            NK => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es)], &[E, Dh(Ee)]]
            ),
            NX => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S, Dh(Es)]]
            ),
            XN => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee)], &[S, Dh(Se)]]
            ),
            XK => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es)], &[E, Dh(Ee)], &[S, Dh(Se)]]
            ),
            XX => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S, Dh(Es)], &[S, Dh(Se)]],
            ),
            KN => (
                static_slice![Token: S],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), Dh(Se)]],
            ),
            KK => (
                static_slice![Token: S],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es), Dh(Ss)], &[E, Dh(Ee), Dh(Se)]],
            ),
            KX => (
                static_slice![Token: S],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), Dh(Se), S, Dh(Es)]],
            ),
            IN => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E, S], &[E, Dh(Ee), Dh(Se)]],
            ),
            IK => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es), S, Dh(Ss)], &[E, Dh(Ee), Dh(Se)]],
            ),
            IX => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E, S], &[E, Dh(Ee), Dh(Se), S, Dh(Es)]],
            ),
            // Deferred patterns.
            NK1 => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E], &[E, Dh(Ee), Dh(Es)]],
            ),
            NX1 => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S], &[Dh(Es)]]
            ),
            X1N => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee)], &[S], &[Dh(Se)]]
            ),
            X1K => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es)], &[E, Dh(Ee)], &[S], &[Dh(Se)]]
            ),
            XK1 => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E], &[E, Dh(Ee), Dh(Es)], &[S, Dh(Se)]]
            ),
            X1K1 => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E], &[E, Dh(Ee), Dh(Es)], &[S], &[Dh(Se)]]
            ),
            X1X => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S, Dh(Es)], &[S], &[Dh(Se)]],
            ),
            XX1 => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S], &[Dh(Es), S, Dh(Se)]],
            ),
            X1X1 => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S], &[Dh(Es), S], &[Dh(Se)]],
            ),
            K1N => (
                static_slice![Token: S],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee)], &[Dh(Se)]],
            ),
            K1K => (
                static_slice![Token: S],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es)], &[E, Dh(Ee)], &[Dh(Se)]],
            ),
            KK1 => (
                static_slice![Token: S],
                static_slice![Token: S],
                message_vec![&[E], &[E, Dh(Ee), Dh(Se), Dh(Es)]],
            ),
            K1K1 => (
                static_slice![Token: S],
                static_slice![Token: S],
                message_vec![&[E], &[E, Dh(Ee), Dh(Es)], &[Dh(Se)]],
            ),
            K1X => (
                static_slice![Token: S],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S, Dh(Es)], &[Dh(Se)]],
            ),
            KX1 => (
                static_slice![Token: S],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), Dh(Se), S], &[Dh(Es)]],
            ),
            K1X1 => (
                static_slice![Token: S],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S], &[Dh(Se), Dh(Es)]],
            ),
            I1N => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E, S], &[E, Dh(Ee)], &[Dh(Se)]],
            ),
            I1K => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es), S], &[E, Dh(Ee)], &[Dh(Se)]],
            ),
            IK1 => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, S], &[E, Dh(Ee), Dh(Se), Dh(Es)]],
            ),
            I1K1 => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, S], &[E, Dh(Ee), Dh(Es)], &[Dh(Se)]],
            ),
            I1X => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E, S], &[E, Dh(Ee), S, Dh(Es)], &[Dh(Se)]],
            ),
            IX1 => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E, S], &[E, Dh(Ee), Dh(Se), S], &[Dh(Es)]],
            ),
            I1X1 => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E, S], &[E, Dh(Ee), S], &[Dh(Se), Dh(Es)]],
            ),
        };

        // Apply modifiers in the order they were written. `Fallback` falls into the catch-all
        // arm and is rejected.
        for modifier in handshake.modifiers.list.iter() {
            match modifier {
                HandshakeModifier::Psk(n) => apply_psk_modifier(&mut patterns, *n)?,
                #[cfg(feature = "hfs")]
                HandshakeModifier::Hfs => apply_hfs_modifier(&mut patterns),
                _ => return Err(PatternProblem::UnsupportedModifier.into()),
            }
        }

        Ok(HandshakeTokens {
            premsg_pattern_i: patterns.0,
            premsg_pattern_r: patterns.1,
            msg_patterns: patterns.2,
        })
    }
}
512
/// Rejects the `hfs` modifier on one-way patterns (`N`, `K`, `X`).
///
/// `apply_hfs_modifier` anchors its `Ekem1` token on a `Dh(Ee)` token. One-way patterns have no
/// such token, so `hfs` cannot be applied to them.
///
/// # Errors
///
/// Returns `PatternProblem::UnsupportedModifier` when `handshake` carries the `hfs` modifier
/// and its pattern is one-way. Returns `Ok(())` otherwise.
#[cfg(feature = "hfs")]
fn check_hfs_and_oneway_conflict(handshake: &HandshakeChoice) -> Result<(), Error> {
    if handshake.is_hfs() && handshake.pattern.is_oneway() {
        Err(PatternProblem::UnsupportedModifier.into())
    } else {
        Ok(())
    }
}
525
526fn apply_psk_modifier(patterns: &mut Patterns, n: u8) -> Result<(), Error> {
528 let tokens = patterns
529 .2
530 .get_mut((n as usize).saturating_sub(1))
531 .ok_or(Error::Pattern(PatternProblem::InvalidPsk))?;
532 if n == 0 {
533 tokens.insert(0, Token::Psk(n));
534 } else {
535 tokens.push(Token::Psk(n));
536 }
537 Ok(())
538}
539
/// Applies the `hfs` modifier in place by inserting the `E1` and `Ekem1` tokens.
///
/// * `E1` goes into the first message that contains an `E` token. It is placed right after that
///   message's first DH token, or right after the `E` token if the message has no DH token.
/// * `Ekem1` goes right after the first `Dh(Ee)` token, searching the messages in order. This
///   search happens after `E1` has been inserted.
///
/// # Panics
///
/// Panics if exactly one of the two tokens could be placed, i.e. the messages contain an `E`
/// but no `Dh(Ee)`, or vice versa.
#[cfg(feature = "hfs")]
fn apply_hfs_modifier(patterns: &mut Patterns) {
    let messages = &mut patterns.2;

    // Find (message index, insert position) for `E1` first, then mutate.
    let e1_slot = messages.iter().enumerate().find_map(|(msg_idx, msg)| {
        let e_pos = msg.iter().position(|token| *token == Token::E)?;
        // NOTE(review): placing `E1` after a DH presumably ensures it is sent encrypted; confirm
        // against the Noise HFS extension spec.
        let insert_at = match msg.iter().position(|token| token.is_dh()) {
            Some(dh_pos) => dh_pos + 1,
            None => e_pos + 1,
        };
        Some((msg_idx, insert_at))
    });
    if let Some((msg_idx, insert_at)) = e1_slot {
        messages[msg_idx].insert(insert_at, Token::E1);
    }

    let ekem1_slot = messages.iter().enumerate().find_map(|(msg_idx, msg)| {
        msg.iter()
            .position(|token| *token == Token::Dh(DhToken::Ee))
            .map(|ee_pos| (msg_idx, ee_pos + 1))
    });
    if let Some((msg_idx, insert_at)) = ekem1_slot {
        messages[msg_idx].insert(insert_at, Token::Ekem1);
    }

    assert!(
        e1_slot.is_some() == ekem1_slot.is_some(),
        "handshake messages contain one of the ['e1', 'ekem1'] tokens, but not the other",
    );
}