1use crate::io::Output;
25use crate::{protocol::PublicKey, Error};
26use bytes::{Bytes, BytesMut};
27use futures::prelude::*;
28use futures::ready;
29use log::{debug, trace};
30use std::{
31 fmt, io,
32 pin::Pin,
33 task::{Context, Poll},
34};
35
36const MAX_NOISE_MSG_LEN: usize = 65535;
38const EXTRA_ENCRYPT_SPACE: usize = 1024;
40pub(crate) const MAX_FRAME_LEN: usize = MAX_NOISE_MSG_LEN - EXTRA_ENCRYPT_SPACE;
42static_assertions::const_assert! {
43 MAX_FRAME_LEN + EXTRA_ENCRYPT_SPACE <= MAX_NOISE_MSG_LEN
44}
45
46pub(crate) struct NoiseFramed<T, S> {
52 io: T,
53 session: S,
54 read_state: ReadState,
55 write_state: WriteState,
56 read_buffer: Vec<u8>,
57 write_buffer: Vec<u8>,
58 decrypt_buffer: BytesMut,
59}
60
61impl<T, S> fmt::Debug for NoiseFramed<T, S> {
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63 f.debug_struct("NoiseFramed")
64 .field("read_state", &self.read_state)
65 .field("write_state", &self.write_state)
66 .finish()
67 }
68}
69
70impl<T> NoiseFramed<T, snow::HandshakeState> {
71 pub(crate) fn new(io: T, state: snow::HandshakeState) -> Self {
73 NoiseFramed {
74 io,
75 session: state,
76 read_state: ReadState::Ready,
77 write_state: WriteState::Ready,
78 read_buffer: Vec::new(),
79 write_buffer: Vec::new(),
80 decrypt_buffer: BytesMut::new(),
81 }
82 }
83
84 pub(crate) fn is_initiator(&self) -> bool {
85 self.session.is_initiator()
86 }
87
88 pub(crate) fn is_responder(&self) -> bool {
89 !self.session.is_initiator()
90 }
91
92 pub(crate) fn into_transport(self) -> Result<(PublicKey, Output<T>), Error> {
101 let dh_remote_pubkey = self.session.get_remote_static().ok_or_else(|| {
102 Error::Io(io::Error::new(
103 io::ErrorKind::Other,
104 "expect key to always be present at end of XX session",
105 ))
106 })?;
107
108 let dh_remote_pubkey = PublicKey::from_slice(dh_remote_pubkey)?;
109
110 let io = NoiseFramed {
111 session: self.session.into_transport_mode()?,
112 io: self.io,
113 read_state: ReadState::Ready,
114 write_state: WriteState::Ready,
115 read_buffer: self.read_buffer,
116 write_buffer: self.write_buffer,
117 decrypt_buffer: self.decrypt_buffer,
118 };
119
120 Ok((dh_remote_pubkey, Output::new(io)))
121 }
122}
123
/// The states of the frame-reading state machine.
#[derive(Debug)]
enum ReadState {
    /// Idle: ready to start reading a new frame.
    Ready,
    /// Reading the 2-byte frame length; `off` bytes of `buf` are filled so far.
    ReadLen { buf: [u8; 2], off: usize },
    /// Reading the frame data; `off` of `len` bytes have been read so far.
    ReadData { len: usize, off: usize },
    /// End of stream was reached.
    ///
    /// `Ok(())` records an EOF seen while reading a frame length, `Err(())`
    /// an EOF seen in the middle of the frame data.
    Eof(Result<(), ()>),
    /// A frame failed to decrypt.
    DecErr,
}
140
/// The states of the frame-writing state machine.
#[derive(Debug)]
enum WriteState {
    /// Idle: a new frame may be submitted.
    Ready,
    /// Writing the 2-byte frame length; `off` bytes of `buf` have been
    /// written so far. `len` is the length of the encrypted frame data that
    /// follows.
    WriteLen {
        len: usize,
        buf: [u8; 2],
        off: usize,
    },
    /// Writing the encrypted frame data; `off` of `len` bytes have been
    /// written so far.
    WriteData { len: usize, off: usize },
    /// The underlying writer accepted zero bytes.
    Eof,
    /// A frame failed to encrypt.
    EncErr,
}

impl WriteState {
    /// Returns `true` only in the idle [`WriteState::Ready`] state.
    fn is_ready(&self) -> bool {
        matches!(self, WriteState::Ready)
    }
}
168
169impl<T, S> futures::stream::Stream for NoiseFramed<T, S>
170where
171 T: AsyncRead + Unpin,
172 S: SessionState + Unpin,
173{
174 type Item = io::Result<Bytes>;
175
176 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
177 let this = Pin::into_inner(self);
178 loop {
179 trace!("read state: {:?}", this.read_state);
180 match this.read_state {
181 ReadState::Ready => {
182 this.read_state = ReadState::ReadLen {
183 buf: [0, 0],
184 off: 0,
185 };
186 }
187 ReadState::ReadLen { mut buf, mut off } => {
188 let n = match read_frame_len(&mut this.io, cx, &mut buf, &mut off) {
189 Poll::Ready(Ok(Some(n))) => n,
190 Poll::Ready(Ok(None)) => {
191 trace!("read: eof");
192 this.read_state = ReadState::Eof(Ok(()));
193 return Poll::Ready(None);
194 }
195 Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))),
196 Poll::Pending => {
197 this.read_state = ReadState::ReadLen { buf, off };
198 return Poll::Pending;
199 }
200 };
201 trace!("read: frame len = {}", n);
202 if n == 0 {
203 trace!("read: empty frame");
204 this.read_state = ReadState::Ready;
205 continue;
206 }
207 this.read_buffer.resize(usize::from(n), 0u8);
208 this.read_state = ReadState::ReadData {
209 len: usize::from(n),
210 off: 0,
211 }
212 }
213 ReadState::ReadData { len, ref mut off } => {
214 let n = {
215 let f =
216 Pin::new(&mut this.io).poll_read(cx, &mut this.read_buffer[*off..len]);
217 match ready!(f) {
218 Ok(n) => n,
219 Err(e) => return Poll::Ready(Some(Err(e))),
220 }
221 };
222 trace!("read: {}/{} bytes", *off + n, len);
223 if n == 0 {
224 trace!("read: eof");
225 this.read_state = ReadState::Eof(Err(()));
226 return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())));
227 }
228 *off += n;
229 if len == *off {
230 trace!("read: decrypting {} bytes", len);
231 this.decrypt_buffer.resize(len, 0);
232 if let Ok(n) = this
233 .session
234 .read_message(&this.read_buffer, &mut this.decrypt_buffer)
235 {
236 this.decrypt_buffer.truncate(n);
237 trace!("read: payload len = {} bytes", n);
238 this.read_state = ReadState::Ready;
239 let view = this.decrypt_buffer.split().freeze();
244 return Poll::Ready(Some(Ok(view)));
245 } else {
246 debug!("read: decryption error");
247 this.read_state = ReadState::DecErr;
248 return Poll::Ready(Some(Err(io::ErrorKind::InvalidData.into())));
249 }
250 }
251 }
252 ReadState::Eof(Ok(())) => {
253 trace!("read: eof");
254 return Poll::Ready(None);
255 }
256 ReadState::Eof(Err(())) => {
257 trace!("read: eof (unexpected)");
258 return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())));
259 }
260 ReadState::DecErr => {
261 return Poll::Ready(Some(Err(io::ErrorKind::InvalidData.into())))
262 }
263 }
264 }
265 }
266}
267
268impl<T, S> futures::sink::Sink<&Vec<u8>> for NoiseFramed<T, S>
269where
270 T: AsyncWrite + Unpin,
271 S: SessionState + Unpin,
272{
273 type Error = io::Error;
274
275 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
276 let this = Pin::into_inner(self);
277 loop {
278 trace!("write state {:?}", this.write_state);
279 match this.write_state {
280 WriteState::Ready => {
281 return Poll::Ready(Ok(()));
282 }
283 WriteState::WriteLen { len, buf, mut off } => {
284 trace!("write: frame len ({}, {:?}, {}/2)", len, buf, off);
285 match write_frame_len(&mut this.io, cx, &buf, &mut off) {
286 Poll::Ready(Ok(true)) => (),
287 Poll::Ready(Ok(false)) => {
288 trace!("write: eof");
289 this.write_state = WriteState::Eof;
290 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
291 }
292 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
293 Poll::Pending => {
294 this.write_state = WriteState::WriteLen { len, buf, off };
295 return Poll::Pending;
296 }
297 }
298 this.write_state = WriteState::WriteData { len, off: 0 }
299 }
300 WriteState::WriteData { len, ref mut off } => {
301 let n = {
302 let f =
303 Pin::new(&mut this.io).poll_write(cx, &this.write_buffer[*off..len]);
304 match ready!(f) {
305 Ok(n) => n,
306 Err(e) => return Poll::Ready(Err(e)),
307 }
308 };
309 if n == 0 {
310 trace!("write: eof");
311 this.write_state = WriteState::Eof;
312 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
313 }
314 *off += n;
315 trace!("write: {}/{} bytes written", *off, len);
316 if len == *off {
317 trace!("write: finished with {} bytes", len);
318 this.write_state = WriteState::Ready;
319 }
320 }
321 WriteState::Eof => {
322 trace!("write: eof");
323 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
324 }
325 WriteState::EncErr => return Poll::Ready(Err(io::ErrorKind::InvalidData.into())),
326 }
327 }
328 }
329
330 fn start_send(self: Pin<&mut Self>, frame: &Vec<u8>) -> Result<(), Self::Error> {
331 assert!(frame.len() <= MAX_FRAME_LEN);
332 let this = Pin::into_inner(self);
333 assert!(this.write_state.is_ready());
334
335 this.write_buffer
336 .resize(frame.len() + EXTRA_ENCRYPT_SPACE, 0u8);
337 match this
338 .session
339 .write_message(frame, &mut this.write_buffer[..])
340 {
341 Ok(n) => {
342 trace!("write: cipher text len = {} bytes", n);
343 this.write_buffer.truncate(n);
344 this.write_state = WriteState::WriteLen {
345 len: n,
346 buf: u16::to_be_bytes(n as u16),
347 off: 0,
348 };
349 Ok(())
350 }
351 Err(e) => {
352 log::error!("encryption error: {:?}", e);
353 this.write_state = WriteState::EncErr;
354 Err(io::ErrorKind::InvalidData.into())
355 }
356 }
357 }
358
359 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
360 ready!(self.as_mut().poll_ready(cx))?;
361 Pin::new(&mut self.io).poll_flush(cx)
362 }
363
364 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
365 ready!(self.as_mut().poll_flush(cx))?;
366 Pin::new(&mut self.io).poll_close(cx)
367 }
368}
369
370pub(crate) trait SessionState {
372 fn read_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result<usize, snow::Error>;
373 fn write_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result<usize, snow::Error>;
374}
375
376impl SessionState for snow::HandshakeState {
377 fn read_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result<usize, snow::Error> {
378 self.read_message(msg, buf)
379 }
380
381 fn write_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result<usize, snow::Error> {
382 self.write_message(msg, buf)
383 }
384}
385
386impl SessionState for snow::TransportState {
387 fn read_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result<usize, snow::Error> {
388 self.read_message(msg, buf)
389 }
390
391 fn write_message(&mut self, msg: &[u8], buf: &mut [u8]) -> Result<usize, snow::Error> {
392 self.write_message(msg, buf)
393 }
394}
395
396fn read_frame_len<R: AsyncRead + Unpin>(
406 mut io: &mut R,
407 cx: &mut Context<'_>,
408 buf: &mut [u8; 2],
409 off: &mut usize,
410) -> Poll<io::Result<Option<u16>>> {
411 loop {
412 match ready!(Pin::new(&mut io).poll_read(cx, &mut buf[*off..])) {
413 Ok(n) => {
414 if n == 0 {
415 return Poll::Ready(Ok(None));
416 }
417 *off += n;
418 if *off == 2 {
419 return Poll::Ready(Ok(Some(u16::from_be_bytes(*buf))));
420 }
421 }
422 Err(e) => {
423 return Poll::Ready(Err(e));
424 }
425 }
426 }
427}
428
429fn write_frame_len<W: AsyncWrite + Unpin>(
439 mut io: &mut W,
440 cx: &mut Context<'_>,
441 buf: &[u8; 2],
442 off: &mut usize,
443) -> Poll<io::Result<bool>> {
444 loop {
445 match ready!(Pin::new(&mut io).poll_write(cx, &buf[*off..])) {
446 Ok(n) => {
447 if n == 0 {
448 return Poll::Ready(Ok(false));
449 }
450 *off += n;
451 if *off == 2 {
452 return Poll::Ready(Ok(true));
453 }
454 }
455 Err(e) => {
456 return Poll::Ready(Err(e));
457 }
458 }
459 }
460}