#[cfg(feature = "risky-raw-split")]
use crate::constants::{CIPHERKEYLEN, MAXHASHLEN};
#[cfg(feature = "hfs")]
use crate::constants::{MAXKEMCTLEN, MAXKEMPUBLEN, MAXKEMSSLEN};
#[cfg(feature = "hfs")]
use crate::types::Kem;
use crate::{
cipherstate::{CipherState, CipherStates},
constants::{MAXDHLEN, MAXMSGLEN, PSKLEN, TAGLEN},
error::{Error, InitStage, StateProblem},
params::{DhToken, HandshakeTokens, MessagePatterns, NoiseParams, Token},
stateless_transportstate::StatelessTransportState,
symmetricstate::SymmetricState,
transportstate::TransportState,
types::{Dh, Hash, Random},
utils::Toggle,
};
use std::{
convert::{TryFrom, TryInto},
fmt,
};
pub struct HandshakeState {
pub(crate) rng: Box<dyn Random>,
pub(crate) symmetricstate: SymmetricState,
pub(crate) cipherstates: CipherStates,
pub(crate) s: Toggle<Box<dyn Dh>>,
pub(crate) e: Toggle<Box<dyn Dh>>,
pub(crate) fixed_ephemeral: bool,
pub(crate) rs: Toggle<[u8; MAXDHLEN]>,
pub(crate) re: Toggle<[u8; MAXDHLEN]>,
pub(crate) initiator: bool,
pub(crate) params: NoiseParams,
pub(crate) psks: [Option<[u8; PSKLEN]>; 10],
#[cfg(feature = "hfs")]
pub(crate) kem: Option<Box<dyn Kem>>,
#[cfg(feature = "hfs")]
pub(crate) kem_re: Option<[u8; MAXKEMPUBLEN]>,
pub(crate) my_turn: bool,
pub(crate) message_patterns: MessagePatterns,
pub(crate) pattern_position: usize,
}
impl HandshakeState {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
rng: Box<dyn Random>,
cipherstate: CipherState,
hasher: Box<dyn Hash>,
s: Toggle<Box<dyn Dh>>,
e: Toggle<Box<dyn Dh>>,
fixed_ephemeral: bool,
rs: Toggle<[u8; MAXDHLEN]>,
re: Toggle<[u8; MAXDHLEN]>,
initiator: bool,
params: NoiseParams,
psks: [Option<[u8; PSKLEN]>; 10],
prologue: &[u8],
cipherstates: CipherStates,
) -> Result<HandshakeState, Error> {
if (s.is_on() && e.is_on() && s.pub_len() != e.pub_len())
|| (s.is_on() && rs.is_on() && s.pub_len() > rs.len())
|| (s.is_on() && re.is_on() && s.pub_len() > re.len())
{
return Err(InitStage::ValidateKeyLengths.into());
}
let tokens = HandshakeTokens::try_from(¶ms.handshake)?;
let mut symmetricstate = SymmetricState::new(cipherstate, hasher);
symmetricstate.initialize(¶ms.name);
symmetricstate.mix_hash(prologue);
let dh_len = s.pub_len();
if initiator {
for token in tokens.premsg_pattern_i {
symmetricstate.mix_hash(
match *token {
Token::S => &s,
Token::E => &e,
_ => unreachable!(),
}
.get()
.ok_or(StateProblem::MissingKeyMaterial)?
.pubkey(),
);
}
for token in tokens.premsg_pattern_r {
symmetricstate.mix_hash(
&match *token {
Token::S => &rs,
Token::E => &re,
_ => unreachable!(),
}
.get()
.ok_or(StateProblem::MissingKeyMaterial)?[..dh_len],
);
}
} else {
for token in tokens.premsg_pattern_i {
symmetricstate.mix_hash(
&match *token {
Token::S => &rs,
Token::E => &re,
_ => unreachable!(),
}
.get()
.ok_or(StateProblem::MissingKeyMaterial)?[..dh_len],
);
}
for token in tokens.premsg_pattern_r {
symmetricstate.mix_hash(
match *token {
Token::S => &s,
Token::E => &e,
_ => unreachable!(),
}
.get()
.ok_or(StateProblem::MissingKeyMaterial)?
.pubkey(),
);
}
}
Ok(HandshakeState {
rng,
symmetricstate,
cipherstates,
s,
e,
fixed_ephemeral,
rs,
re,
initiator,
params,
psks,
#[cfg(feature = "hfs")]
kem: None,
#[cfg(feature = "hfs")]
kem_re: None,
my_turn: initiator,
message_patterns: tokens.msg_patterns,
pattern_position: 0,
})
}
pub(crate) fn dh_len(&self) -> usize {
self.s.pub_len()
}
#[cfg(feature = "hfs")]
pub(crate) fn set_kem(&mut self, kem: Box<dyn Kem>) {
self.kem = Some(kem);
}
fn dh(&self, token: &DhToken) -> Result<[u8; MAXDHLEN], Error> {
let mut dh_out = [0u8; MAXDHLEN];
let (dh, key) = match (token, self.is_initiator()) {
(DhToken::Ee, _) => (&self.e, &self.re),
(DhToken::Ss, _) => (&self.s, &self.rs),
(DhToken::Se, true) | (DhToken::Es, false) => (&self.s, &self.re),
(DhToken::Es, true) | (DhToken::Se, false) => (&self.e, &self.rs),
};
if !(dh.is_on() && key.is_on()) {
return Err(StateProblem::MissingKeyMaterial.into());
}
dh.dh(&**key, &mut dh_out)?;
Ok(dh_out)
}
pub fn was_write_payload_encrypted(&self) -> bool {
self.symmetricstate.has_key()
}
pub fn write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error> {
let checkpoint = self.symmetricstate.checkpoint();
match self._write_message(payload, message) {
Ok(res) => {
self.pattern_position += 1;
self.my_turn = false;
Ok(res)
},
Err(err) => {
self.symmetricstate.restore(checkpoint);
Err(err)
},
}
}
fn _write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error> {
if !self.my_turn {
return Err(StateProblem::NotTurnToWrite.into());
} else if self.pattern_position >= self.message_patterns.len() {
return Err(StateProblem::HandshakeAlreadyFinished.into());
}
let mut byte_index = 0;
for token in self.message_patterns[self.pattern_position].iter() {
match token {
Token::E => {
if byte_index + self.e.pub_len() > message.len() {
return Err(Error::Input);
}
if !self.fixed_ephemeral {
self.e.generate(&mut *self.rng);
}
let pubkey = self.e.pubkey();
message[byte_index..byte_index + pubkey.len()].copy_from_slice(pubkey);
byte_index += pubkey.len();
self.symmetricstate.mix_hash(pubkey);
if self.params.handshake.is_psk() {
self.symmetricstate.mix_key(pubkey);
}
self.e.enable();
},
Token::S => {
if !self.s.is_on() {
return Err(StateProblem::MissingKeyMaterial.into());
} else if byte_index + self.s.pub_len() > message.len() {
return Err(Error::Input);
}
byte_index += self
.symmetricstate
.encrypt_and_mix_hash(self.s.pubkey(), &mut message[byte_index..])?;
},
Token::Psk(n) => match self.psks[*n as usize] {
Some(psk) => {
self.symmetricstate.mix_key_and_hash(&psk);
},
None => {
return Err(StateProblem::MissingPsk.into());
},
},
Token::Dh(t) => {
let dh_out = self.dh(t)?;
self.symmetricstate.mix_key(&dh_out[..self.dh_len()]);
},
#[cfg(feature = "hfs")]
Token::E1 => {
let kem = self.kem.as_mut().ok_or(Error::Input)?;
if kem.pub_len() > message.len() {
return Err(Error::Input);
}
kem.generate(&mut *self.rng);
byte_index += self
.symmetricstate
.encrypt_and_mix_hash(kem.pubkey(), &mut message[byte_index..])?;
},
#[cfg(feature = "hfs")]
Token::Ekem1 => {
let kem = self.kem.as_mut().unwrap();
let mut kem_output_buf = [0; MAXKEMSSLEN];
let mut ciphertext_buf = [0; MAXKEMCTLEN];
if kem.ciphertext_len() > message.len() {
return Err(Error::Input);
}
let kem_output = &mut kem_output_buf[..kem.shared_secret_len()];
let ciphertext = &mut ciphertext_buf[..kem.ciphertext_len()];
let pubkey = &self.kem_re.as_ref().unwrap()[..kem.pub_len()];
if kem.encapsulate(pubkey, kem_output, ciphertext).is_err() {
return Err(Error::Kem);
}
byte_index += self.symmetricstate.encrypt_and_mix_hash(
&ciphertext[..kem.ciphertext_len()],
&mut message[byte_index..],
)?;
self.symmetricstate.mix_key(&kem_output[..kem.shared_secret_len()]);
},
}
}
if byte_index + payload.len() + TAGLEN > message.len() {
return Err(Error::Input);
}
byte_index +=
self.symmetricstate.encrypt_and_mix_hash(payload, &mut message[byte_index..])?;
if byte_index > MAXMSGLEN {
return Err(Error::Input);
}
if self.pattern_position == (self.message_patterns.len() - 1) {
self.symmetricstate.split(&mut self.cipherstates.0, &mut self.cipherstates.1);
}
Ok(byte_index)
}
pub fn read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, Error> {
let checkpoint = self.symmetricstate.checkpoint();
match self._read_message(message, payload) {
Ok(res) => {
self.pattern_position += 1;
self.my_turn = true;
Ok(res)
},
Err(err) => {
self.symmetricstate.restore(checkpoint);
Err(err)
},
}
}
fn _read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, Error> {
if message.len() > MAXMSGLEN {
return Err(Error::Input);
} else if self.my_turn {
return Err(StateProblem::NotTurnToRead.into());
} else if self.pattern_position >= self.message_patterns.len() {
return Err(StateProblem::HandshakeAlreadyFinished.into());
}
let last = self.pattern_position == (self.message_patterns.len() - 1);
let dh_len = self.dh_len();
let mut ptr = message;
for token in self.message_patterns[self.pattern_position].iter() {
match token {
Token::E => {
if ptr.len() < dh_len {
return Err(Error::Input);
}
self.re[..dh_len].copy_from_slice(&ptr[..dh_len]);
ptr = &ptr[dh_len..];
self.symmetricstate.mix_hash(&self.re[..dh_len]);
if self.params.handshake.is_psk() {
self.symmetricstate.mix_key(&self.re[..dh_len]);
}
self.re.enable();
},
Token::S => {
let data = if self.symmetricstate.has_key() {
if ptr.len() < dh_len + TAGLEN {
return Err(Error::Input);
}
let temp = &ptr[..dh_len + TAGLEN];
ptr = &ptr[dh_len + TAGLEN..];
temp
} else {
if ptr.len() < dh_len {
return Err(Error::Input);
}
let temp = &ptr[..dh_len];
ptr = &ptr[dh_len..];
temp
};
self.symmetricstate.decrypt_and_mix_hash(data, &mut self.rs[..dh_len])?;
self.rs.enable();
},
Token::Psk(n) => match self.psks[*n as usize] {
Some(psk) => {
self.symmetricstate.mix_key_and_hash(&psk);
},
None => {
return Err(StateProblem::MissingPsk.into());
},
},
Token::Dh(t) => {
let dh_out = self.dh(t)?;
self.symmetricstate.mix_key(&dh_out[..self.dh_len()]);
},
#[cfg(feature = "hfs")]
Token::E1 => {
let kem = self.kem.as_ref().ok_or(Error::Kem)?;
let read_len = if self.symmetricstate.has_key() {
kem.pub_len() + TAGLEN
} else {
kem.pub_len()
};
if ptr.len() < read_len {
return Err(Error::Input);
}
let mut kem_re = [0; MAXKEMPUBLEN];
self.symmetricstate
.decrypt_and_mix_hash(&ptr[..read_len], &mut kem_re[..kem.pub_len()])?;
self.kem_re = Some(kem_re);
ptr = &ptr[read_len..];
},
#[cfg(feature = "hfs")]
Token::Ekem1 => {
let kem = self.kem.as_ref().unwrap();
let read_len = if self.symmetricstate.has_key() {
kem.ciphertext_len() + TAGLEN
} else {
kem.ciphertext_len()
};
if ptr.len() < read_len {
return Err(Error::Input);
}
let mut ciphertext_buf = [0; MAXKEMCTLEN];
let ciphertext = &mut ciphertext_buf[..kem.ciphertext_len()];
self.symmetricstate.decrypt_and_mix_hash(&ptr[..read_len], ciphertext)?;
let mut kem_output_buf = [0; MAXKEMSSLEN];
let kem_output = &mut kem_output_buf[..kem.shared_secret_len()];
kem.decapsulate(ciphertext, kem_output).map_err(|_| Error::Kem)?;
self.symmetricstate.mix_key(&kem_output[..kem.shared_secret_len()]);
ptr = &ptr[read_len..];
},
}
}
self.symmetricstate.decrypt_and_mix_hash(ptr, payload)?;
if last {
self.symmetricstate.split(&mut self.cipherstates.0, &mut self.cipherstates.1);
}
let payload_len =
if self.symmetricstate.has_key() { ptr.len() - TAGLEN } else { ptr.len() };
Ok(payload_len)
}
pub fn set_psk(&mut self, location: usize, key: &[u8]) -> Result<(), Error> {
if key.len() != PSKLEN || self.psks.len() <= location {
return Err(Error::Input);
}
let mut new_psk = [0u8; PSKLEN];
new_psk.copy_from_slice(key);
self.psks[location as usize] = Some(new_psk);
Ok(())
}
pub fn get_remote_static(&self) -> Option<&[u8]> {
self.rs.get().map(|rs| &rs[..self.dh_len()])
}
pub fn get_handshake_hash(&self) -> &[u8] {
self.symmetricstate.handshake_hash()
}
pub fn is_initiator(&self) -> bool {
self.initiator
}
pub fn is_handshake_finished(&self) -> bool {
self.pattern_position == self.message_patterns.len()
}
pub fn is_my_turn(&self) -> bool {
self.my_turn
}
#[cfg(feature = "risky-raw-split")]
pub fn dangerously_get_raw_split(&mut self) -> ([u8; CIPHERKEYLEN], [u8; CIPHERKEYLEN]) {
let mut output = ([0u8; MAXHASHLEN], [0u8; MAXHASHLEN]);
self.symmetricstate.split_raw(&mut output.0, &mut output.1);
(output.0[..CIPHERKEYLEN].try_into().unwrap(), output.1[..CIPHERKEYLEN].try_into().unwrap())
}
pub fn into_transport_mode(self) -> Result<TransportState, Error> {
self.try_into()
}
pub fn into_stateless_transport_mode(self) -> Result<StatelessTransportState, Error> {
self.try_into()
}
}
impl fmt::Debug for HandshakeState {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("HandshakeState").finish()
}
}