litep2p/crypto/noise/mod.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2// Copyright 2023 litep2p developers
3//
4// Permission is hereby granted, free of charge, to any person obtaining a
5// copy of this software and associated documentation files (the "Software"),
6// to deal in the Software without restriction, including without limitation
7// the rights to use, copy, modify, merge, publish, distribute, sublicense,
8// and/or sell copies of the Software, and to permit persons to whom the
9// Software is furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20// DEALINGS IN THE SOFTWARE.
21
22//! Noise handshake and transport implementations.
23
24use crate::{
25    config::Role,
26    crypto::{ed25519::Keypair, PublicKey},
27    error::{NegotiationError, ParseError},
28    PeerId,
29};
30
31use bytes::{Buf, Bytes, BytesMut};
32use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
33use prost::Message;
34use snow::{Builder, HandshakeState, TransportState};
35
36use std::{
37    fmt, io,
38    pin::Pin,
39    task::{Context, Poll},
40};
41
// Provides `protocol::Resolver`, the crypto resolver handed to `snow::Builder::with_resolver`.
mod protocol;
mod x25519_spec;

/// Protobuf types for the Noise handshake payload (`NoiseHandshakePayload`), included from
/// `$OUT_DIR/noise.rs` at compile time (presumably generated by the crate's build script — TODO
/// confirm).
mod handshake_schema {
    include!(concat!(env!("OUT_DIR"), "/noise.rs"));
}
48
/// Noise parameters: `XX` handshake pattern, X25519 DH, ChaCha20-Poly1305 AEAD, SHA-256 hash.
const NOISE_PARAMETERS: &str = "Noise_XX_25519_ChaChaPoly_SHA256";

/// Prefix of static key signatures for domain separation.
pub(crate) const STATIC_KEY_DOMAIN: &str = "noise-libp2p-static-key:";

/// Maximum Noise message size, authentication tag included.
///
/// Every frame is prefixed with its length as a big-endian `u16`, so this must not exceed
/// `u16::MAX`; the Noise specification caps messages at 65535 bytes as well. (A value of 65536
/// would be encoded as a zero-length prefix.)
const MAX_NOISE_MSG_LEN: usize = u16::MAX as usize;

/// Size of the AEAD authentication tag (16 bytes for ChaChaPoly) that encryption appends to
/// every message; a frame's plaintext is its length minus this value.
const NOISE_EXTRA_ENCRYPT_SPACE: usize = 16;

/// Max read ahead factor for the noise socket.
///
/// Specifies how many multiples of `MAX_NOISE_MSG_LEN` are read from the socket
/// using one call to `poll_read()`.
pub(crate) const MAX_READ_AHEAD_FACTOR: usize = 5;

/// Maximum write buffer size, counted in maximum-size Noise frames (each `MAX_NOISE_MSG_LEN`
/// bytes plus a 2-byte length prefix) that are batched before being written to the socket.
pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = 2;

/// Max. length for Noise protocol message payloads (plaintext bytes per frame).
pub const MAX_FRAME_LEN: usize = MAX_NOISE_MSG_LEN - NOISE_EXTRA_ENCRYPT_SPACE;

/// Logging target for the file.
const LOG_TARGET: &str = "litep2p::crypto::noise";
75
/// Phase of a Noise session: handshake still in progress, or transport mode after it finished.
#[derive(Debug)]
// NOTE(review): presumably `HandshakeState` is much larger than `TransportState`; boxing was
// not considered worth it for a per-connection value — confirm.
#[allow(clippy::large_enum_variant)]
enum NoiseState {
    /// Handshake messages are still being exchanged.
    Handshake(HandshakeState),
    /// Handshake completed; messages are encrypted with the derived transport keys.
    Transport(TransportState),
}
82
/// Noise session state together with the local static DH keypair and identity payload.
pub struct NoiseContext {
    /// Local static Diffie-Hellman keypair used by the Noise handshake.
    keypair: snow::Keypair,
    /// Current session state (handshake or transport).
    noise: NoiseState,
    /// Role of the local node in the handshake.
    role: Role,
    /// Protobuf-encoded `NoiseHandshakePayload` (identity key and signature over the static
    /// public key); sent as the payload of this side's identity-carrying handshake message.
    pub payload: Vec<u8>,
}
89
90impl fmt::Debug for NoiseContext {
91    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92        f.debug_struct("NoiseContext")
93            .field("public", &self.noise)
94            .field("payload", &self.payload)
95            .field("role", &self.role)
96            .finish()
97    }
98}
99
100impl NoiseContext {
101    /// Assemble Noise payload and return [`NoiseContext`].
102    fn assemble(
103        noise: snow::HandshakeState,
104        keypair: snow::Keypair,
105        id_keys: &Keypair,
106        role: Role,
107    ) -> Result<Self, NegotiationError> {
108        let noise_payload = handshake_schema::NoiseHandshakePayload {
109            identity_key: Some(PublicKey::Ed25519(id_keys.public()).to_protobuf_encoding()),
110            identity_sig: Some(
111                id_keys.sign(&[STATIC_KEY_DOMAIN.as_bytes(), keypair.public.as_ref()].concat()),
112            ),
113            ..Default::default()
114        };
115
116        let mut payload = Vec::with_capacity(noise_payload.encoded_len());
117        noise_payload.encode(&mut payload).map_err(ParseError::from)?;
118
119        Ok(Self {
120            noise: NoiseState::Handshake(noise),
121            keypair,
122            payload,
123            role,
124        })
125    }
126
127    pub fn new(keypair: &Keypair, role: Role) -> Result<Self, NegotiationError> {
128        tracing::trace!(target: LOG_TARGET, ?role, "create new noise configuration");
129
130        let builder: Builder<'_> = Builder::with_resolver(
131            NOISE_PARAMETERS.parse().expect("qed; Valid noise pattern"),
132            Box::new(protocol::Resolver),
133        );
134
135        let dh_keypair = builder.generate_keypair()?;
136        let static_key = &dh_keypair.private;
137
138        let noise = match role {
139            Role::Dialer => builder.local_private_key(static_key).build_initiator()?,
140            Role::Listener => builder.local_private_key(static_key).build_responder()?,
141        };
142
143        Self::assemble(noise, dh_keypair, keypair, role)
144    }
145
146    /// Create new [`NoiseContext`] with prologue.
147    #[cfg(feature = "webrtc")]
148    pub fn with_prologue(id_keys: &Keypair, prologue: Vec<u8>) -> Result<Self, NegotiationError> {
149        let noise: Builder<'_> = Builder::with_resolver(
150            NOISE_PARAMETERS.parse().expect("qed; Valid noise pattern"),
151            Box::new(protocol::Resolver),
152        );
153
154        let keypair = noise.generate_keypair()?;
155
156        let noise = noise
157            .local_private_key(&keypair.private)
158            .prologue(&prologue)
159            .build_initiator()?;
160
161        Self::assemble(noise, keypair, id_keys, Role::Dialer)
162    }
163
164    /// Get remote public key from the received Noise payload.
165    #[cfg(feature = "webrtc")]
166    pub fn get_remote_public_key(&mut self, reply: &[u8]) -> Result<PublicKey, NegotiationError> {
167        let (len_slice, reply) = reply.split_at(2);
168        let len = u16::from_be_bytes(
169            len_slice
170                .try_into()
171                .map_err(|_| NegotiationError::ParseError(ParseError::InvalidPublicKey))?,
172        ) as usize;
173
174        let mut buffer = vec![0u8; len];
175
176        let NoiseState::Handshake(ref mut noise) = self.noise else {
177            tracing::error!(target: LOG_TARGET, "invalid state to read the second handshake message");
178            debug_assert!(false);
179            return Err(NegotiationError::StateMismatch);
180        };
181
182        let res = noise.read_message(reply, &mut buffer)?;
183        buffer.truncate(res);
184
185        let payload = handshake_schema::NoiseHandshakePayload::decode(buffer.as_slice())
186            .map_err(|err| NegotiationError::ParseError(err.into()))?;
187
188        let identity = payload.identity_key.ok_or(NegotiationError::PeerIdMissing)?;
189        PublicKey::from_protobuf_encoding(&identity).map_err(|err| err.into())
190    }
191
192    /// Get first message.
193    ///
194    /// Listener only sends one message (the payload)
195    pub fn first_message(&mut self, role: Role) -> Result<Vec<u8>, NegotiationError> {
196        match role {
197            Role::Dialer => {
198                tracing::trace!(target: LOG_TARGET, "get noise dialer first message");
199
200                let NoiseState::Handshake(ref mut noise) = self.noise else {
201                    tracing::error!(target: LOG_TARGET, "invalid state to read the first handshake message");
202                    debug_assert!(false);
203                    return Err(NegotiationError::StateMismatch);
204                };
205
206                let mut buffer = vec![0u8; 256];
207                let nwritten = noise.write_message(&[], &mut buffer)?;
208                buffer.truncate(nwritten);
209
210                let size = nwritten as u16;
211                let mut size = size.to_be_bytes().to_vec();
212                size.append(&mut buffer);
213
214                Ok(size)
215            }
216            Role::Listener => self.second_message(),
217        }
218    }
219
220    /// Get second message.
221    ///
222    /// Only the dialer sends the second message.
223    pub fn second_message(&mut self) -> Result<Vec<u8>, NegotiationError> {
224        tracing::trace!(target: LOG_TARGET, "get noise paylod message");
225
226        let NoiseState::Handshake(ref mut noise) = self.noise else {
227            tracing::error!(target: LOG_TARGET, "invalid state to read the first handshake message");
228            debug_assert!(false);
229            return Err(NegotiationError::StateMismatch);
230        };
231
232        let mut buffer = vec![0u8; 2048];
233        let nwritten = noise.write_message(&self.payload, &mut buffer)?;
234        buffer.truncate(nwritten);
235
236        let size = nwritten as u16;
237        let mut size = size.to_be_bytes().to_vec();
238        size.append(&mut buffer);
239
240        Ok(size)
241    }
242
243    /// Read handshake message.
244    async fn read_handshake_message<T: AsyncRead + AsyncWrite + Unpin>(
245        &mut self,
246        io: &mut T,
247    ) -> Result<Bytes, NegotiationError> {
248        let mut size = BytesMut::zeroed(2);
249        io.read_exact(&mut size).await?;
250        let size = size.get_u16();
251
252        let mut message = BytesMut::zeroed(size as usize);
253        io.read_exact(&mut message).await?;
254
255        let mut out = BytesMut::new();
256        out.resize(message.len() + 200, 0u8); // TODO: correct overhead
257
258        let NoiseState::Handshake(ref mut noise) = self.noise else {
259            tracing::error!(target: LOG_TARGET, "invalid state to read handshake message");
260            debug_assert!(false);
261            return Err(NegotiationError::StateMismatch);
262        };
263
264        let nread = noise.read_message(&message, &mut out)?;
265        out.truncate(nread);
266
267        Ok(out.freeze())
268    }
269
270    fn read_message(&mut self, message: &[u8], out: &mut [u8]) -> Result<usize, snow::Error> {
271        match self.noise {
272            NoiseState::Handshake(ref mut noise) => noise.read_message(message, out),
273            NoiseState::Transport(ref mut noise) => noise.read_message(message, out),
274        }
275    }
276
277    fn write_message(&mut self, message: &[u8], out: &mut [u8]) -> Result<usize, snow::Error> {
278        match self.noise {
279            NoiseState::Handshake(ref mut noise) => noise.write_message(message, out),
280            NoiseState::Transport(ref mut noise) => noise.write_message(message, out),
281        }
282    }
283
284    /// Convert Noise into transport mode.
285    fn into_transport(self) -> Result<NoiseContext, NegotiationError> {
286        let transport = match self.noise {
287            NoiseState::Handshake(noise) => noise.into_transport_mode()?,
288            NoiseState::Transport(_) => return Err(NegotiationError::StateMismatch),
289        };
290
291        Ok(NoiseContext {
292            keypair: self.keypair,
293            payload: self.payload,
294            role: self.role,
295            noise: NoiseState::Transport(transport),
296        })
297    }
298}
299
/// State machine of the read side of [`NoiseSocket`].
enum ReadState {
    /// Read more bytes from the socket into `read_buffer[nread..max_read]`.
    ReadData {
        /// Exclusive upper index into `read_buffer` for the next socket read.
        max_read: usize,
    },
    /// Parse the next frame's length prefix (or reuse an already parsed one) and check whether
    /// the whole frame is buffered.
    ReadFrameLen,
    /// Decrypt the current frame and hand its plaintext to the caller.
    ProcessNextFrame {
        /// Decrypted plaintext that did not fit into the caller's buffer.
        pending: Option<Vec<u8>>,
        /// Number of bytes of `pending` already returned to the caller.
        offset: usize,
        /// Total number of valid plaintext bytes in `pending`.
        size: usize,
        /// Encrypted size of the frame `pending` came from; the socket's read offset is advanced
        /// by it once `pending` has been drained.
        frame_size: usize,
    },
}
312
/// State machine of the write side of [`NoiseSocket`].
enum WriteState {
    /// Encrypting chunks of the caller's data into `encrypt_buffer`.
    Ready {
        /// Index in `encrypt_buffer` where the next frame (length prefix included) starts.
        offset: usize,
        /// Plaintext bytes encrypted so far in the current batch.
        size: usize,
        /// Bytes of `encrypt_buffer` filled so far (length prefixes plus ciphertext).
        encrypted_size: usize,
    },
    /// Writing `encrypt_buffer[offset..encrypted_size]` to the socket.
    WriteFrame {
        /// Number of bytes of the batch already written to the socket.
        offset: usize,
        /// Plaintext size of the batch, reported to the caller once it is fully written.
        size: usize,
        /// Total encrypted size of the batch.
        encrypted_size: usize,
    },
}
325
/// Encrypted stream returned by a successful Noise [`handshake`].
///
/// Data is framed as a 2-byte big-endian length prefix followed by a Noise transport message.
pub struct NoiseSocket<S: AsyncRead + AsyncWrite + Unpin> {
    /// Underlying (unencrypted) socket.
    io: S,
    /// Noise context, in transport mode after the handshake.
    noise: NoiseContext,
    /// Length of a frame whose prefix has been parsed but whose body is not yet processed.
    current_frame_size: Option<usize>,
    /// Write-side state machine.
    write_state: WriteState,
    /// Encrypted, length-prefixed frames waiting to be written to `io`.
    encrypt_buffer: Vec<u8>,
    /// Index of the first unconsumed byte in `read_buffer`.
    offset: usize,
    /// Number of valid bytes in `read_buffer`.
    nread: usize,
    /// Read-side state machine.
    read_state: ReadState,
    /// Raw bytes read from `io`: the read-ahead window plus room for one more framed message.
    read_buffer: Vec<u8>,
    /// Default upper bound for `ReadData::max_read` (read-ahead window size).
    canonical_max_read: usize,
    /// Scratch buffer for frames whose plaintext does not fit into the caller's buffer; `None`
    /// while it is lent out as `ProcessNextFrame::pending`.
    decrypt_buffer: Option<Vec<u8>>,
}
339
340impl<S: AsyncRead + AsyncWrite + Unpin> NoiseSocket<S> {
341    fn new(
342        io: S,
343        noise: NoiseContext,
344        max_read_ahead_factor: usize,
345        max_write_buffer_size: usize,
346    ) -> Self {
347        Self {
348            io,
349            noise,
350            read_buffer: vec![
351                0u8;
352                max_read_ahead_factor * MAX_NOISE_MSG_LEN + (2 + MAX_NOISE_MSG_LEN)
353            ],
354            nread: 0usize,
355            offset: 0usize,
356            current_frame_size: None,
357            write_state: WriteState::Ready {
358                offset: 0usize,
359                size: 0usize,
360                encrypted_size: 0usize,
361            },
362            encrypt_buffer: vec![0u8; max_write_buffer_size * (MAX_NOISE_MSG_LEN + 2)],
363            decrypt_buffer: Some(vec![0u8; MAX_FRAME_LEN]),
364            read_state: ReadState::ReadData {
365                max_read: max_read_ahead_factor * MAX_NOISE_MSG_LEN,
366            },
367            canonical_max_read: max_read_ahead_factor * MAX_NOISE_MSG_LEN,
368        }
369    }
370
371    fn reset_read_state(&mut self, remaining: usize) {
372        match remaining {
373            0 => {
374                self.nread = 0;
375            }
376            1 => {
377                self.read_buffer[0] = self.read_buffer[self.nread - 1];
378                self.nread = 1;
379            }
380            _ => panic!("invalid state"),
381        }
382
383        self.offset = 0;
384        self.read_state = ReadState::ReadData {
385            max_read: self.canonical_max_read,
386        };
387    }
388}
389
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for NoiseSocket<S> {
    /// Read decrypted data into `buf`.
    ///
    /// Raw bytes are read ahead into `read_buffer` and split into frames (2-byte big-endian
    /// length prefix followed by a Noise transport message). Each successful call returns
    /// plaintext from at most one frame; if a frame's plaintext does not fit into `buf`, the rest
    /// is kept in the scratch buffer and returned by subsequent calls.
    ///
    /// Errors: `UnexpectedEof` when the socket yields 0 bytes, `InvalidData` for a frame not
    /// longer than the authentication tag or one that fails to decrypt, `PermissionDenied` if
    /// the internal read offset ever passes the number of buffered bytes.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = Pin::into_inner(self);

        loop {
            match this.read_state {
                ReadState::ReadData { max_read } => {
                    // Append to `read_buffer` at `nread`, up to (excluding) `max_read`.
                    let nread = match Pin::new(&mut this.io)
                        .poll_read(cx, &mut this.read_buffer[this.nread..max_read])
                    {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Err(error)) => return Poll::Ready(Err(error)),
                        // A zero-byte read into a non-empty slice is treated as EOF.
                        Poll::Ready(Ok(nread)) => match nread == 0 {
                            true => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())),
                            false => nread,
                        },
                    };

                    tracing::trace!(target: LOG_TARGET, ?nread, "read data from socket");

                    this.nread += nread;
                    this.read_state = ReadState::ReadFrameLen;
                }
                ReadState::ReadFrameLen => {
                    // Number of buffered bytes not consumed yet.
                    let mut remaining = match this.nread.checked_sub(this.offset) {
                        Some(remaining) => remaining,
                        None => {
                            tracing::error!(target: LOG_TARGET, "offset is larger than the number of bytes read");
                            return Poll::Ready(Err(io::ErrorKind::PermissionDenied.into()));
                        }
                    };

                    // Fewer than two unconsumed bytes: compact the buffer (moving a leftover
                    // byte, if any, to the front) and read more. `current_frame_size` is kept.
                    if remaining < 2 {
                        tracing::trace!(target: LOG_TARGET, "reset read buffer");
                        this.reset_read_state(remaining);
                        continue;
                    }

                    // get frame size, either from current or previous iteration
                    let frame_size = match this.current_frame_size.take() {
                        Some(frame_size) => frame_size,
                        None => {
                            // Big-endian `u16` length prefix.
                            let frame_size = (this.read_buffer[this.offset] as u16) << 8
                                | this.read_buffer[this.offset + 1] as u16;
                            this.offset += 2;
                            remaining -= 2;
                            frame_size as usize
                        }
                    };

                    tracing::trace!(target: LOG_TARGET, "current frame size = {frame_size}");

                    // The frame body is not fully buffered yet.
                    if remaining < frame_size {
                        // `read_buffer` can fit the full frame size.
                        if this.nread + frame_size < this.canonical_max_read {
                            tracing::trace!(
                                target: LOG_TARGET,
                                max_size = ?this.canonical_max_read,
                                next_frame_size = ?(this.nread + frame_size),
                                "read buffer can fit the full frame",
                            );

                            this.current_frame_size = Some(frame_size);
                            this.read_state = ReadState::ReadData {
                                max_read: this.canonical_max_read,
                            };
                            continue;
                        }

                        tracing::trace!(target: LOG_TARGET, "use auxiliary buffer extension");

                        // use the auxiliary memory at the end of the read buffer for reading the
                        // frame
                        // (the bound equals `offset + frame_size`, i.e. the end of this frame)
                        this.current_frame_size = Some(frame_size);
                        this.read_state = ReadState::ReadData {
                            max_read: this.nread + frame_size - remaining,
                        };
                        continue;
                    }

                    // A frame must be longer than the authentication tag.
                    if frame_size <= NOISE_EXTRA_ENCRYPT_SPACE {
                        tracing::error!(
                            target: LOG_TARGET,
                            ?frame_size,
                            max_size = ?NOISE_EXTRA_ENCRYPT_SPACE,
                            "invalid frame size",
                        );
                        return Poll::Ready(Err(io::ErrorKind::InvalidData.into()));
                    }

                    // The full frame is buffered. Its size travels via `current_frame_size`; the
                    // `frame_size: 0` below is only a placeholder.
                    this.current_frame_size = Some(frame_size);
                    this.read_state = ReadState::ProcessNextFrame {
                        pending: None,
                        offset: 0usize,
                        size: 0usize,
                        frame_size: 0usize,
                    };
                }
                ReadState::ProcessNextFrame {
                    ref mut pending,
                    offset,
                    size,
                    frame_size,
                } => match pending.take() {
                    // Leftover plaintext of an already decrypted frame.
                    Some(pending) => match buf.len() >= pending[offset..size].len() {
                        true => {
                            // The rest fits: return the scratch buffer and skip past the frame.
                            let copy_size = pending[offset..size].len();
                            buf[..copy_size].copy_from_slice(&pending[offset..copy_size + offset]);

                            this.read_state = ReadState::ReadFrameLen;
                            this.decrypt_buffer = Some(pending);
                            this.offset += frame_size;
                            return Poll::Ready(Ok(copy_size));
                        }
                        false => {
                            // Fill `buf` completely and keep the rest for the next call.
                            buf.copy_from_slice(&pending[offset..buf.len() + offset]);

                            this.read_state = ReadState::ProcessNextFrame {
                                pending: Some(pending),
                                offset: offset + buf.len(),
                                size,
                                frame_size,
                            };
                            return Poll::Ready(Ok(buf.len()));
                        }
                    },
                    None => {
                        let frame_size =
                            this.current_frame_size.take().expect("`frame_size` to exist");

                        // The plaintext is `frame_size - NOISE_EXTRA_ENCRYPT_SPACE` bytes.
                        match buf.len() >= frame_size - NOISE_EXTRA_ENCRYPT_SPACE {
                            // `buf` is large enough: decrypt straight into it.
                            true => match this.noise.read_message(
                                &this.read_buffer[this.offset..this.offset + frame_size],
                                buf,
                            ) {
                                Err(error) => {
                                    tracing::error!(target: LOG_TARGET, ?error, "failed to decrypt message");
                                    return Poll::Ready(Err(io::ErrorKind::InvalidData.into()));
                                }
                                Ok(nread) => {
                                    this.offset += frame_size;
                                    this.read_state = ReadState::ReadFrameLen;
                                    return Poll::Ready(Ok(nread));
                                }
                            },
                            // `buf` is too small: decrypt into the scratch buffer, return the
                            // first `buf.len()` bytes now and the rest on later calls.
                            false => {
                                let mut buffer =
                                    this.decrypt_buffer.take().expect("buffer to exist");

                                match this.noise.read_message(
                                    &this.read_buffer[this.offset..this.offset + frame_size],
                                    &mut buffer,
                                ) {
                                    Err(error) => {
                                        tracing::error!(target: LOG_TARGET, ?error, "failed to decrypt message");
                                        return Poll::Ready(Err(io::ErrorKind::InvalidData.into()));
                                    }
                                    Ok(nread) => {
                                        buf.copy_from_slice(&buffer[..buf.len()]);
                                        this.read_state = ReadState::ProcessNextFrame {
                                            pending: Some(buffer),
                                            offset: buf.len(),
                                            size: nread,
                                            frame_size,
                                        };
                                        return Poll::Ready(Ok(buf.len()));
                                    }
                                }
                            }
                        }
                    }
                },
            }
        }
    }
}
570
571impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NoiseSocket<S> {
572    fn poll_write(
573        self: Pin<&mut Self>,
574        cx: &mut Context<'_>,
575        buf: &[u8],
576    ) -> Poll<io::Result<usize>> {
577        let this = Pin::into_inner(self);
578        let mut chunks = buf.chunks(MAX_FRAME_LEN).peekable();
579
580        loop {
581            match this.write_state {
582                WriteState::Ready {
583                    offset,
584                    size,
585                    encrypted_size,
586                } => {
587                    let Some(chunk) = chunks.next() else {
588                        break;
589                    };
590
591                    match this.noise.write_message(chunk, &mut this.encrypt_buffer[offset + 2..]) {
592                        Err(error) => {
593                            tracing::error!(target: LOG_TARGET, ?error, "failed to encrypt message");
594                            return Poll::Ready(Err(io::ErrorKind::InvalidData.into()));
595                        }
596                        Ok(nwritten) => {
597                            this.encrypt_buffer[offset] = (nwritten >> 8) as u8;
598                            this.encrypt_buffer[offset + 1] = (nwritten & 0xff) as u8;
599
600                            if let Some(next_chunk) = chunks.peek() {
601                                if next_chunk.len() + NOISE_EXTRA_ENCRYPT_SPACE + 2
602                                    <= this.encrypt_buffer[offset + nwritten + 2..].len()
603                                {
604                                    this.write_state = WriteState::Ready {
605                                        offset: offset + nwritten + 2,
606                                        size: size + chunk.len(),
607                                        encrypted_size: encrypted_size + nwritten + 2,
608                                    };
609                                    continue;
610                                }
611                            }
612
613                            this.write_state = WriteState::WriteFrame {
614                                offset: 0usize,
615                                size: size + chunk.len(),
616                                encrypted_size: encrypted_size + nwritten + 2,
617                            };
618                        }
619                    }
620                }
621                WriteState::WriteFrame {
622                    ref mut offset,
623                    size,
624                    encrypted_size,
625                } => loop {
626                    match futures::ready!(Pin::new(&mut this.io)
627                        .poll_write(cx, &this.encrypt_buffer[*offset..encrypted_size]))
628                    {
629                        Ok(nwritten) => {
630                            *offset += nwritten;
631
632                            if offset == &encrypted_size {
633                                this.write_state = WriteState::Ready {
634                                    offset: 0usize,
635                                    size: 0usize,
636                                    encrypted_size: 0usize,
637                                };
638                                return Poll::Ready(Ok(size));
639                            }
640                        }
641                        Err(error) => return Poll::Ready(Err(error)),
642                    }
643                },
644            }
645        }
646
647        Poll::Ready(Ok(0))
648    }
649
650    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
651        Pin::new(&mut self.io).poll_flush(cx)
652    }
653
654    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
655        Pin::new(&mut self.io).poll_close(cx)
656    }
657}
658
659/// Try to parse `PeerId` from received `NoiseHandshakePayload`
660fn parse_peer_id(buf: &[u8]) -> Result<PeerId, NegotiationError> {
661    match handshake_schema::NoiseHandshakePayload::decode(buf) {
662        Ok(payload) => {
663            let identity = payload.identity_key.ok_or(NegotiationError::PeerIdMissing)?;
664
665            let public_key = PublicKey::from_protobuf_encoding(&identity)?;
666            Ok(PeerId::from_public_key(&public_key))
667        }
668        Err(err) => Err(ParseError::from(err).into()),
669    }
670}
671
672/// Perform Noise handshake.
673pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
674    mut io: S,
675    keypair: &Keypair,
676    role: Role,
677    max_read_ahead_factor: usize,
678    max_write_buffer_size: usize,
679) -> Result<(NoiseSocket<S>, PeerId), NegotiationError> {
680    tracing::debug!(target: LOG_TARGET, ?role, "start noise handshake");
681
682    let mut noise = NoiseContext::new(keypair, role)?;
683    let peer = match role {
684        Role::Dialer => {
685            // write initial message
686            let first_message = noise.first_message(Role::Dialer)?;
687            let _ = io.write(&first_message).await?;
688            io.flush().await?;
689
690            // read back response which contains the remote peer id
691            let message = noise.read_handshake_message(&mut io).await?;
692
693            // send the final message which contains local peer id
694            let second_message = noise.second_message()?;
695            let _ = io.write(&second_message).await?;
696            io.flush().await?;
697
698            parse_peer_id(&message)?
699        }
700        Role::Listener => {
701            // read remote's first message
702            let _ = noise.read_handshake_message(&mut io).await?;
703
704            // send local peer id.
705            let second_message = noise.second_message()?;
706            let _ = io.write(&second_message).await?;
707            io.flush().await?;
708
709            // read remote's second message which contains their peer id
710            let message = noise.read_handshake_message(&mut io).await?;
711            parse_peer_id(&message)?
712        }
713    };
714
715    Ok((
716        NoiseSocket::new(
717            io,
718            noise.into_transport()?,
719            max_read_ahead_factor,
720            max_write_buffer_size,
721        ),
722        peer,
723    ))
724}
725
// TODO: add more tests
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;
    use tokio::net::{TcpListener, TcpStream};
    use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};

    /// Wrap a connected tokio TCP stream into a boxed `futures`-compatible socket.
    fn into_futures_io(stream: TcpStream) -> Box<tokio_util::compat::Compat<TcpStream>> {
        let stream = TokioAsyncReadCompatExt::compat(stream).into_inner();
        Box::new(TokioAsyncWriteCompatExt::compat_write(stream))
    }

    /// Both sides learn each other's peer ID and can exchange data afterwards.
    #[tokio::test]
    async fn noise_handshake() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let dialer_keys = Keypair::generate();
        let listener_keys = Keypair::generate();

        let dialer_peer = PeerId::from_public_key(&dialer_keys.public().into());
        let listener_peer = PeerId::from_public_key(&listener_keys.public().into());

        let address = "[::1]:0".parse::<SocketAddr>().unwrap();
        let tcp_listener = TcpListener::bind(address).await.unwrap();

        let (outbound, inbound) = tokio::join!(
            TcpStream::connect(tcp_listener.local_addr().unwrap()),
            tcp_listener.accept()
        );
        let dialer_io = into_futures_io(outbound.unwrap());
        let listener_io = into_futures_io(inbound.unwrap().0);

        let (dialer_result, listener_result) = tokio::join!(
            handshake(
                dialer_io,
                &dialer_keys,
                Role::Dialer,
                MAX_READ_AHEAD_FACTOR,
                MAX_WRITE_BUFFER_SIZE
            ),
            handshake(
                listener_io,
                &listener_keys,
                Role::Listener,
                MAX_READ_AHEAD_FACTOR,
                MAX_WRITE_BUFFER_SIZE
            )
        );
        let (mut dialer_socket, seen_by_dialer) = dialer_result.unwrap();
        let (mut listener_socket, seen_by_listener) = listener_result.unwrap();

        assert_eq!(seen_by_dialer, listener_peer);
        assert_eq!(seen_by_listener, dialer_peer);

        // verify the connection works by sending a string from dialer to listener
        let mut received = vec![0u8; 512];
        let sent = dialer_socket.write(b"hello, world").await.unwrap();
        listener_socket.read_exact(&mut received[..sent]).await.unwrap();

        assert_eq!(std::str::from_utf8(&received[..sent]), Ok("hello, world"));
    }

    /// Bytes that are not a valid protobuf payload yield a parse error.
    #[test]
    fn invalid_peer_id_schema() {
        let error = parse_peer_id(&[1, 2, 3, 4]).unwrap_err();
        assert!(matches!(error, NegotiationError::ParseError(_)), "invalid error");
    }
}