1use crate::{
2 cipherstate::CipherStates,
3 constants::{MAXDHLEN, MAXMSGLEN, TAGLEN},
4 error::{Error, StateProblem},
5 handshakestate::HandshakeState,
6 params::HandshakePattern,
7 utils::Toggle,
8};
9use std::{convert::TryFrom, fmt};
10
/// Transport-phase state built from a finished [`HandshakeState`] (via the
/// `TryFrom<HandshakeState>` impl), used to encrypt and decrypt messages after the handshake.
pub struct TransportState {
    /// The pair of cipher states taken from the handshake. `.0` is the one the initiator
    /// writes with and the responder reads with; `.1` is used for the opposite direction.
    cipherstates: CipherStates,
    /// Pattern of the completed handshake; its `is_oneway()` is checked so that a responder
    /// cannot write and an initiator cannot read on a one-way pattern.
    pattern: HandshakePattern,
    /// Value of `HandshakeState::dh_len()` at construction; only this many leading bytes of
    /// `rs` are exposed by `get_remote_static`.
    dh_len: usize,
    /// Remote static key storage (fixed `MAXDHLEN` buffer). When the `Toggle` is off,
    /// `get_remote_static` returns `None`.
    rs: Toggle<[u8; MAXDHLEN]>,
    /// `true` if this side is the handshake initiator; selects which cipher state is used
    /// for sending versus receiving.
    initiator: bool,
}
23
24impl TransportState {
25 pub(crate) fn new(handshake: HandshakeState) -> Result<Self, Error> {
26 if !handshake.is_handshake_finished() {
27 return Err(StateProblem::HandshakeNotFinished.into());
28 }
29
30 let dh_len = handshake.dh_len();
31 let HandshakeState { cipherstates, params, rs, initiator, .. } = handshake;
32 let pattern = params.handshake.pattern;
33
34 Ok(TransportState { cipherstates, pattern, dh_len, rs, initiator })
35 }
36
37 pub fn get_remote_static(&self) -> Option<&[u8]> {
44 self.rs.get().map(|rs| &rs[..self.dh_len])
45 }
46
47 pub fn write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error> {
57 if !self.initiator && self.pattern.is_oneway() {
58 return Err(StateProblem::OneWay.into());
59 } else if payload.len() + TAGLEN > MAXMSGLEN || payload.len() + TAGLEN > message.len() {
60 return Err(Error::Input);
61 }
62
63 let cipher =
64 if self.initiator { &mut self.cipherstates.0 } else { &mut self.cipherstates.1 };
65 cipher.encrypt(payload, message)
66 }
67
68 pub fn read_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error> {
79 if payload.len() > MAXMSGLEN {
80 Err(Error::Input)
81 } else if self.initiator && self.pattern.is_oneway() {
82 Err(StateProblem::OneWay.into())
83 } else {
84 let cipher =
85 if self.initiator { &mut self.cipherstates.1 } else { &mut self.cipherstates.0 };
86 cipher.decrypt(payload, message)
87 }
88 }
89
90 pub fn rekey_outgoing(&mut self) {
95 if self.initiator {
96 self.cipherstates.rekey_initiator()
97 } else {
98 self.cipherstates.rekey_responder()
99 }
100 }
101
102 pub fn rekey_incoming(&mut self) {
107 if self.initiator {
108 self.cipherstates.rekey_responder()
109 } else {
110 self.cipherstates.rekey_initiator()
111 }
112 }
113
114 pub fn rekey_manually(&mut self, initiator: Option<&[u8]>, responder: Option<&[u8]>) {
116 if let Some(key) = initiator {
117 self.rekey_initiator_manually(key);
118 }
119 if let Some(key) = responder {
120 self.rekey_responder_manually(key);
121 }
122 }
123
124 pub fn rekey_initiator_manually(&mut self, key: &[u8]) {
126 self.cipherstates.rekey_initiator_manually(key)
127 }
128
129 pub fn rekey_responder_manually(&mut self, key: &[u8]) {
131 self.cipherstates.rekey_responder_manually(key)
132 }
133
134 pub fn set_receiving_nonce(&mut self, nonce: u64) {
136 if self.initiator {
137 self.cipherstates.1.set_nonce(nonce);
138 } else {
139 self.cipherstates.0.set_nonce(nonce);
140 }
141 }
142
143 pub fn receiving_nonce(&self) -> u64 {
149 if self.initiator {
150 self.cipherstates.1.nonce()
151 } else {
152 self.cipherstates.0.nonce()
153 }
154 }
155
156 pub fn sending_nonce(&self) -> u64 {
162 if self.initiator {
163 self.cipherstates.0.nonce()
164 } else {
165 self.cipherstates.1.nonce()
166 }
167 }
168
169 pub fn is_initiator(&self) -> bool {
171 self.initiator
172 }
173}
174
175impl fmt::Debug for TransportState {
176 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
177 fmt.debug_struct("TransportState").finish()
178 }
179}
180
181impl TryFrom<HandshakeState> for TransportState {
182 type Error = Error;
183
184 fn try_from(old: HandshakeState) -> Result<Self, Self::Error> {
185 TransportState::new(old)
186 }
187}