1mod handshake;
2
3use futures_io::{AsyncRead, AsyncWrite};
4pub(crate) use handshake::{IoSession, MidHandshake};
5use rustls::{ConnectionCommon, SideData};
6use std::io::{self, IoSlice, Read, Write};
7use std::ops::{Deref, DerefMut};
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
/// Tracks which directions of a TLS stream are still usable.
#[derive(Debug)]
pub enum TlsState {
    /// Early-data phase; only compiled in with the `early-data` feature.
    ///
    /// NOTE(review): the payload looks like a position plus the buffered
    /// early-data bytes — confirm against the code that builds this variant.
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    /// Neither direction has been shut down.
    Stream,
    /// The read side has been shut down; the write side has not.
    ReadShutdown,
    /// The write side has been shut down; the read side has not.
    WriteShutdown,
    /// Both the read side and the write side have been shut down.
    FullyShutdown,
}
20
21impl TlsState {
22 #[inline]
23 pub fn shutdown_read(&mut self) {
24 match *self {
25 TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
26 _ => *self = TlsState::ReadShutdown,
27 }
28 }
29
30 #[inline]
31 pub fn shutdown_write(&mut self) {
32 match *self {
33 TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown,
34 _ => *self = TlsState::WriteShutdown,
35 }
36 }
37
38 #[inline]
39 pub fn writeable(&self) -> bool {
40 !matches!(*self, TlsState::WriteShutdown | TlsState::FullyShutdown)
41 }
42
43 #[inline]
44 pub fn readable(&self) -> bool {
45 !matches!(*self, TlsState::ReadShutdown | TlsState::FullyShutdown)
46 }
47
48 #[inline]
49 #[cfg(feature = "early-data")]
50 pub fn is_early_data(&self) -> bool {
51 matches!(self, TlsState::EarlyData(..))
52 }
53
54 #[inline]
55 #[cfg(not(feature = "early-data"))]
56 pub const fn is_early_data(&self) -> bool {
57 false
58 }
59}
60
/// A short-lived view that pairs a borrowed transport with a borrowed TLS
/// session so the two can be driven together.
pub struct Stream<'a, IO, C> {
    /// The underlying transport that TLS bytes are read from and written to.
    pub io: &'a mut IO,
    /// The TLS session state machine.
    pub session: &'a mut C,
    /// When `true`, the read loops on this view stop pulling from `io`.
    pub eof: bool,
}
66
67impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
68where
69 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
70 SD: SideData,
71{
72 pub fn new(io: &'a mut IO, session: &'a mut C) -> Self {
73 Stream {
74 io,
75 session,
76 eof: false,
79 }
80 }
81
82 pub fn set_eof(mut self, eof: bool) -> Self {
83 self.eof = eof;
84 self
85 }
86
87 pub fn as_mut_pin(&mut self) -> Pin<&mut Self> {
88 Pin::new(self)
89 }
90
91 pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
92 struct Reader<'a, 'b, T> {
93 io: &'a mut T,
94 cx: &'a mut Context<'b>,
95 }
96
97 impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> {
98 #[inline]
99 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
100 match Pin::new(&mut self.io).poll_read(self.cx, buf) {
101 Poll::Ready(Ok(n)) => Ok(n),
102 Poll::Ready(Err(err)) => Err(err),
103 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
104 }
105 }
106 }
107
108 let mut reader = Reader { io: self.io, cx };
109
110 let n = match self.session.read_tls(&mut reader) {
111 Ok(n) => n,
112 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
113 Err(err) => return Poll::Ready(Err(err)),
114 };
115
116 let stats = self.session.process_new_packets().map_err(|err| {
117 let _ = self.write_io(cx);
121
122 io::Error::new(io::ErrorKind::InvalidData, err)
123 })?;
124
125 if stats.peer_has_closed() && self.session.is_handshaking() {
126 return Poll::Ready(Err(io::Error::new(
127 io::ErrorKind::UnexpectedEof,
128 "tls handshake alert",
129 )));
130 }
131
132 Poll::Ready(Ok(n))
133 }
134
135 pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
136 struct Writer<'a, 'b, T> {
137 io: &'a mut T,
138 cx: &'a mut Context<'b>,
139 }
140
141 impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> {
142 #[inline]
143 fn poll_with<U>(
144 &mut self,
145 f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
146 ) -> io::Result<U> {
147 match f(Pin::new(&mut self.io), self.cx) {
148 Poll::Ready(result) => result,
149 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
150 }
151 }
152 }
153
154 impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
155 #[inline]
156 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
157 self.poll_with(|io, cx| io.poll_write(cx, buf))
158 }
159
160 #[inline]
161 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
162 self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
163 }
164
165 fn flush(&mut self) -> io::Result<()> {
166 self.poll_with(|io, cx| io.poll_flush(cx))
167 }
168 }
169
170 let mut writer = Writer { io: self.io, cx };
171
172 match self.session.write_tls(&mut writer) {
173 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
174 result => Poll::Ready(result),
175 }
176 }
177
178 pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
179 let mut wrlen = 0;
180 let mut rdlen = 0;
181
182 loop {
183 let mut write_would_block = false;
184 let mut read_would_block = false;
185 let mut need_flush = false;
186
187 while self.session.wants_write() {
188 match self.write_io(cx) {
189 Poll::Ready(Ok(n)) => {
190 wrlen += n;
191 need_flush = true;
192 }
193 Poll::Pending => {
194 write_would_block = true;
195 break;
196 }
197 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
198 }
199 }
200
201 if need_flush {
202 match Pin::new(&mut self.io).poll_flush(cx) {
203 Poll::Ready(Ok(())) => (),
204 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
205 Poll::Pending => write_would_block = true,
206 }
207 }
208
209 while !self.eof && self.session.wants_read() {
210 match self.read_io(cx) {
211 Poll::Ready(Ok(0)) => self.eof = true,
212 Poll::Ready(Ok(n)) => rdlen += n,
213 Poll::Pending => {
214 read_would_block = true;
215 break;
216 }
217 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
218 }
219 }
220
221 return match (self.eof, self.session.is_handshaking()) {
222 (true, true) => {
223 let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
224 Poll::Ready(Err(err))
225 }
226 (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
227 (_, true) if write_would_block || read_would_block => {
228 if rdlen != 0 || wrlen != 0 {
229 Poll::Ready(Ok((rdlen, wrlen)))
230 } else {
231 Poll::Pending
232 }
233 }
234 (..) => continue,
235 };
236 }
237 }
238}
239
240impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C>
241where
242 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
243 SD: SideData,
244{
245 fn poll_read(
246 mut self: Pin<&mut Self>,
247 cx: &mut Context<'_>,
248 buf: &mut [u8],
249 ) -> Poll<io::Result<usize>> {
250 let mut io_pending = false;
251
252 while !self.eof && self.session.wants_read() {
254 match self.read_io(cx) {
255 Poll::Ready(Ok(0)) => {
256 break;
257 }
258 Poll::Ready(Ok(_)) => (),
259 Poll::Pending => {
260 io_pending = true;
261 break;
262 }
263 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
264 }
265 }
266
267 match self.session.reader().read(buf) {
268 Ok(n) => Poll::Ready(Ok(n)),
277
278 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
280 if !io_pending {
281 cx.waker().wake_by_ref();
287 }
288
289 Poll::Pending
290 }
291
292 Err(err) => Poll::Ready(Err(err)),
293 }
294 }
295}
296
297impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'a, IO, C>
298where
299 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
300 SD: SideData,
301{
302 fn poll_write(
303 mut self: Pin<&mut Self>,
304 cx: &mut Context,
305 buf: &[u8],
306 ) -> Poll<io::Result<usize>> {
307 let mut pos = 0;
308
309 while pos != buf.len() {
310 let mut would_block = false;
311
312 match self.session.writer().write(&buf[pos..]) {
313 Ok(n) => pos += n,
314 Err(err) => return Poll::Ready(Err(err)),
315 };
316
317 while self.session.wants_write() {
318 match self.write_io(cx) {
319 Poll::Ready(Ok(0)) | Poll::Pending => {
320 would_block = true;
321 break;
322 }
323 Poll::Ready(Ok(_)) => (),
324 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
325 }
326 }
327
328 return match (pos, would_block) {
329 (0, true) => Poll::Pending,
330 (n, true) => Poll::Ready(Ok(n)),
331 (_, false) => continue,
332 };
333 }
334
335 Poll::Ready(Ok(pos))
336 }
337
338 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
339 self.session.writer().flush()?;
340 while self.session.wants_write() {
341 ready!(self.write_io(cx))?;
342 }
343 Pin::new(&mut self.io).poll_flush(cx)
344 }
345
346 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
347 while self.session.wants_write() {
348 ready!(self.write_io(cx))?;
349 }
350 Pin::new(&mut self.io).poll_close(cx)
351 }
352}
353
/// Bundles a borrowed async reader with a task [`Context`] so the pair can be
/// used through the blocking-style `Read` trait.
pub struct SyncReadAdapter<'a, 'b, T> {
    /// The reader that is polled on each `read` call.
    pub io: &'a mut T,
    /// The task context passed to the reader when it is polled.
    pub cx: &'a mut Context<'b>,
}
362
363impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> {
364 #[inline]
365 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
366 match Pin::new(&mut self.io).poll_read(self.cx, buf) {
367 Poll::Ready(Ok(n)) => Ok(n),
368 Poll::Ready(Err(err)) => Err(err),
369 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
370 }
371 }
372}
373
374#[cfg(test)]
375mod test_stream;