thrift/transport/socket.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::convert::From;
19use std::io;
20use std::io::{ErrorKind, Read, Write};
21use std::net::{Shutdown, TcpStream, ToSocketAddrs};
22
23use super::{ReadHalf, TIoChannel, WriteHalf};
24use crate::{new_transport_error, TransportErrorKind};
25
/// Bidirectional TCP/IP channel.
///
/// A `TTcpChannel` either wraps a connected `TcpStream` or is unopened.
/// The `Default` value is an unopened channel.
///
/// # Examples
///
/// Create a `TTcpChannel`.
///
/// ```no_run
/// use std::io::{Read, Write};
/// use thrift::transport::TTcpChannel;
///
/// let mut c = TTcpChannel::new();
/// c.open("localhost:9090").unwrap();
///
/// // `read_exact` / `write_all` loop until the whole buffer is transferred;
/// // a single `read` / `write` call may transfer fewer bytes than requested.
/// let mut buf = vec![0u8; 4];
/// c.read_exact(&mut buf).unwrap();
/// c.write_all(&[0, 1, 2]).unwrap();
/// ```
///
/// Create a `TTcpChannel` by wrapping an existing `TcpStream`.
///
/// ```no_run
/// use std::io::{Read, Write};
/// use std::net::TcpStream;
/// use thrift::transport::TTcpChannel;
///
/// let stream = TcpStream::connect("127.0.0.1:9189").unwrap();
///
/// // no need to call c.open() since we've already connected above
/// let mut c = TTcpChannel::with_stream(stream);
///
/// let mut buf = vec![0u8; 4];
/// c.read_exact(&mut buf).unwrap();
/// c.write_all(&[0, 1, 2]).unwrap();
/// ```
#[derive(Debug, Default)]
pub struct TTcpChannel {
    /// The connected stream, or `None` while the channel is unopened.
    stream: Option<TcpStream>,
}
64
65impl TTcpChannel {
66    /// Create an uninitialized `TTcpChannel`.
67    ///
68    /// The returned instance must be opened using `TTcpChannel::open(...)`
69    /// before it can be used.
70    pub fn new() -> TTcpChannel {
71        TTcpChannel { stream: None }
72    }
73
74    /// Create a `TTcpChannel` that wraps an existing `TcpStream`.
75    ///
76    /// The passed-in stream is assumed to have been opened before being wrapped
77    /// by the created `TTcpChannel` instance.
78    pub fn with_stream(stream: TcpStream) -> TTcpChannel {
79        TTcpChannel {
80            stream: Some(stream),
81        }
82    }
83
84    /// Connect to `remote_address`, which should implement `ToSocketAddrs` trait.
85    pub fn open<A: ToSocketAddrs>(&mut self, remote_address: A) -> crate::Result<()> {
86        if self.stream.is_some() {
87            Err(new_transport_error(
88                TransportErrorKind::AlreadyOpen,
89                "tcp connection previously opened",
90            ))
91        } else {
92            match TcpStream::connect(&remote_address) {
93                Ok(s) => {
94                    self.stream = Some(s);
95                    Ok(())
96                }
97                Err(e) => Err(From::from(e)),
98            }
99        }
100    }
101
102    /// Shut down this channel.
103    ///
104    /// Both send and receive halves are closed, and this instance can no
105    /// longer be used to communicate with another endpoint.
106    pub fn close(&mut self) -> crate::Result<()> {
107        self.if_set(|s| s.shutdown(Shutdown::Both))
108            .map_err(From::from)
109    }
110
111    fn if_set<F, T>(&mut self, mut stream_operation: F) -> io::Result<T>
112    where
113        F: FnMut(&mut TcpStream) -> io::Result<T>,
114    {
115        if let Some(ref mut s) = self.stream {
116            stream_operation(s)
117        } else {
118            Err(io::Error::new(
119                ErrorKind::NotConnected,
120                "tcp endpoint not connected",
121            ))
122        }
123    }
124}
125
126impl TIoChannel for TTcpChannel {
127    fn split(self) -> crate::Result<(ReadHalf<Self>, WriteHalf<Self>)>
128    where
129        Self: Sized,
130    {
131        let mut s = self;
132
133        s.stream
134            .as_mut()
135            .and_then(|s| s.try_clone().ok())
136            .map(|cloned| {
137                let read_half = ReadHalf::new(TTcpChannel {
138                    stream: s.stream.take(),
139                });
140                let write_half = WriteHalf::new(TTcpChannel {
141                    stream: Some(cloned),
142                });
143                (read_half, write_half)
144            })
145            .ok_or_else(|| {
146                new_transport_error(
147                    TransportErrorKind::Unknown,
148                    "cannot clone underlying tcp stream",
149                )
150            })
151    }
152}
153
154impl Read for TTcpChannel {
155    fn read(&mut self, b: &mut [u8]) -> io::Result<usize> {
156        self.if_set(|s| s.read(b))
157    }
158}
159
160impl Write for TTcpChannel {
161    fn write(&mut self, b: &[u8]) -> io::Result<usize> {
162        self.if_set(|s| s.write(b))
163    }
164
165    fn flush(&mut self) -> io::Result<()> {
166        self.if_set(|s| s.flush())
167    }
168}