hickory_resolver/
hosts.rs

//! Static host-name resolution answered from the system hosts file.
2
3use std::collections::HashMap;
4use std::fs::File;
5use std::io;
6use std::net::IpAddr;
7use std::path::Path;
8use std::str::FromStr;
9use std::sync::Arc;
10
11use crate::proto::op::Query;
12use crate::proto::rr::rdata::PTR;
13use crate::proto::rr::{Name, RecordType};
14use crate::proto::rr::{RData, Record};
15use tracing::warn;
16
17use crate::dns_lru;
18use crate::lookup::Lookup;
19
20#[derive(Debug, Default)]
21struct LookupType {
22    /// represents the A record type
23    a: Option<Lookup>,
24    /// represents the AAAA record type
25    aaaa: Option<Lookup>,
26}
27
28/// Configuration for the local hosts file
29#[derive(Debug, Default)]
30pub struct Hosts {
31    /// Name -> RDatas map
32    by_name: HashMap<Name, LookupType>,
33}
34
35impl Hosts {
36    /// Creates a new configuration from the system hosts file,
37    /// only works for Windows and Unix-like OSes,
38    /// will return empty configuration on others
39    #[cfg(any(unix, windows))]
40    pub fn from_system() -> io::Result<Self> {
41        Self::from_file(hosts_path())
42    }
43
44    /// Creates a default configuration for non Windows or Unix-like OSes
45    #[cfg(not(any(unix, windows)))]
46    pub fn from_system() -> io::Result<Self> {
47        Ok(Hosts::default())
48    }
49
50    /// parse configuration from `path`
51    #[cfg(any(unix, windows))]
52    pub(crate) fn from_file(path: impl AsRef<Path>) -> io::Result<Self> {
53        let file = File::open(path)?;
54        let mut hosts = Self::default();
55        hosts.read_hosts_conf(file)?;
56        Ok(hosts)
57    }
58
59    /// Look up the addresses for the given host from the system hosts file.
60    pub fn lookup_static_host(&self, query: &Query) -> Option<Lookup> {
61        if self.by_name.is_empty() {
62            return None;
63        }
64
65        let mut name = query.name().clone();
66        name.set_fqdn(true);
67        match query.query_type() {
68            RecordType::A | RecordType::AAAA => {
69                let val = self.by_name.get(&name)?;
70
71                match query.query_type() {
72                    RecordType::A => val.a.clone(),
73                    RecordType::AAAA => val.aaaa.clone(),
74                    _ => None,
75                }
76            }
77            RecordType::PTR => {
78                let ip = name.parse_arpa_name().ok()?;
79
80                let ip_addr = ip.addr();
81                let records = self
82                    .by_name
83                    .iter()
84                    .filter(|(_, v)| match ip_addr {
85                        IpAddr::V4(ip) => match v.a.as_ref() {
86                            Some(lookup) => lookup
87                                .iter()
88                                .any(|r| r.ip_addr().map(|it| it == ip).unwrap_or_default()),
89                            None => false,
90                        },
91                        IpAddr::V6(ip) => match v.aaaa.as_ref() {
92                            Some(lookup) => lookup
93                                .iter()
94                                .any(|r| r.ip_addr().map(|it| it == ip).unwrap_or_default()),
95                            None => false,
96                        },
97                    })
98                    .map(|(n, _)| {
99                        Record::from_rdata(
100                            name.clone(),
101                            dns_lru::MAX_TTL,
102                            RData::PTR(PTR(n.clone())),
103                        )
104                    })
105                    .collect::<Arc<[Record]>>();
106
107                if records.is_empty() {
108                    return None;
109                }
110
111                Some(Lookup::new_with_max_ttl(query.clone(), records))
112            }
113            _ => None,
114        }
115    }
116
117    /// Insert a new Lookup for the associated `Name` and `RecordType`
118    pub fn insert(&mut self, mut name: Name, record_type: RecordType, lookup: Lookup) {
119        assert!(record_type == RecordType::A || record_type == RecordType::AAAA);
120
121        name.set_fqdn(true);
122        let lookup_type = self.by_name.entry(name.clone()).or_default();
123
124        let new_lookup = {
125            let old_lookup = match record_type {
126                RecordType::A => lookup_type.a.get_or_insert_with(|| {
127                    let query = Query::query(name.clone(), record_type);
128                    Lookup::new_with_max_ttl(query, Arc::from([]))
129                }),
130                RecordType::AAAA => lookup_type.aaaa.get_or_insert_with(|| {
131                    let query = Query::query(name.clone(), record_type);
132                    Lookup::new_with_max_ttl(query, Arc::from([]))
133                }),
134                _ => {
135                    tracing::warn!("unsupported IP type from Hosts file: {:#?}", record_type);
136                    return;
137                }
138            };
139
140            old_lookup.append(lookup)
141        };
142
143        // replace the appended version
144        match record_type {
145            RecordType::A => lookup_type.a = Some(new_lookup),
146            RecordType::AAAA => lookup_type.aaaa = Some(new_lookup),
147            _ => tracing::warn!("unsupported IP type from Hosts file"),
148        }
149    }
150
151    /// parse configuration from `src`
152    pub fn read_hosts_conf(&mut self, src: impl io::Read) -> io::Result<()> {
153        use std::io::{BufRead, BufReader};
154
155        // lines in the src should have the form `addr host1 host2 host3 ...`
156        // line starts with `#` will be regarded with comments and ignored,
157        // also empty line also will be ignored,
158        // if line only include `addr` without `host` will be ignored,
159        // the src will be parsed to map in the form `Name -> LookUp`.
160
161        for (line_index, line) in BufReader::new(src).lines().enumerate() {
162            let line = line?;
163
164            // Remove byte-order mark if present
165            let line = if line_index == 0 && line.starts_with('\u{feff}') {
166                // BOM is 3 bytes
167                &line[3..]
168            } else {
169                &line
170            };
171
172            // Remove comments from the line
173            let line = match line.split_once('#') {
174                Some((line, _)) => line,
175                None => line,
176            }
177            .trim();
178
179            if line.is_empty() {
180                continue;
181            }
182
183            let mut iter = line.split_whitespace();
184            let addr = match iter.next() {
185                Some(addr) => match IpAddr::from_str(addr) {
186                    Ok(addr) => RData::from(addr),
187                    Err(_) => {
188                        warn!("could not parse an IP from hosts file ({addr:?})");
189                        continue;
190                    }
191                },
192                None => continue,
193            };
194
195            for domain in iter {
196                let domain = domain.to_lowercase();
197                let Ok(mut name) = Name::from_str(&domain) else {
198                    continue;
199                };
200
201                name.set_fqdn(true);
202                let record = Record::from_rdata(name.clone(), dns_lru::MAX_TTL, addr.clone());
203                match addr {
204                    RData::A(..) => {
205                        let query = Query::query(name.clone(), RecordType::A);
206                        let lookup = Lookup::new_with_max_ttl(query, Arc::from([record]));
207                        self.insert(name.clone(), RecordType::A, lookup);
208                    }
209                    RData::AAAA(..) => {
210                        let query = Query::query(name.clone(), RecordType::AAAA);
211                        let lookup = Lookup::new_with_max_ttl(query, Arc::from([record]));
212                        self.insert(name.clone(), RecordType::AAAA, lookup);
213                    }
214                    _ => {
215                        warn!("unsupported IP type from Hosts file: {:#?}", addr);
216                        continue;
217                    }
218                };
219
220                // TODO: insert reverse lookup as well.
221            }
222        }
223
224        Ok(())
225    }
226}
227
/// Location of the hosts file on Unix-like systems.
#[cfg(unix)]
fn hosts_path() -> &'static str {
    const UNIX_HOSTS_FILE: &str = "/etc/hosts";
    UNIX_HOSTS_FILE
}
232
/// Returns the location of the Windows hosts file,
/// `%SystemRoot%\System32\drivers\etc\hosts`.
///
/// If the `SystemRoot` environment variable is unset, the conventional `C:\Windows`
/// is used instead of panicking. A missing file then surfaces as an I/O error from
/// `Hosts::from_file` rather than aborting the process.
#[cfg(windows)]
fn hosts_path() -> std::path::PathBuf {
    const DEFAULT_SYSTEM_ROOT: &str = "C:\\Windows";

    std::env::var_os("SystemRoot")
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|| std::path::PathBuf::from(DEFAULT_SYSTEM_ROOT))
        .join("System32\\drivers\\etc\\hosts")
}
240
#[cfg(any(unix, windows))]
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Directory holding the resolver test fixtures; honours `TDNS_WORKSPACE_ROOT`.
    fn tests_dir() -> String {
        let server_path = match env::var("TDNS_WORKSPACE_ROOT") {
            Ok(root) => root,
            Err(_) => "../..".to_owned(),
        };
        format!("{server_path}/crates/resolver/tests")
    }

    /// Queries `hosts` for `name`/`record_type` and returns the answer's record data.
    ///
    /// Panics if the lookup yields no answer.
    fn answers(hosts: &Hosts, name: &str, record_type: RecordType) -> Vec<RData> {
        let query = Query::query(Name::from_str(name).unwrap(), record_type);
        hosts
            .lookup_static_host(&query)
            .unwrap()
            .iter()
            .cloned()
            .collect()
    }

    #[test]
    fn test_read_hosts_conf() {
        let hosts = Hosts::from_file(format!("{}/hosts", tests_dir())).unwrap();

        assert_eq!(
            answers(&hosts, "localhost.", RecordType::A),
            vec![RData::A(Ipv4Addr::LOCALHOST.into())]
        );
        assert_eq!(
            answers(&hosts, "localhost.", RecordType::AAAA),
            vec![RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into())]
        );
        assert_eq!(
            answers(&hosts, "broadcasthost", RecordType::A),
            vec![RData::A(Ipv4Addr::new(255, 255, 255, 255).into())]
        );
        assert_eq!(
            answers(&hosts, "example.com", RecordType::A),
            vec![RData::A(Ipv4Addr::new(10, 0, 1, 102).into())]
        );
        assert_eq!(
            answers(&hosts, "a.example.com", RecordType::A),
            vec![RData::A(Ipv4Addr::new(10, 0, 1, 111).into())]
        );
        assert_eq!(
            answers(&hosts, "b.example.com", RecordType::A),
            vec![RData::A(Ipv4Addr::new(10, 0, 1, 111).into())]
        );

        // `by_name` is a HashMap, so PTR answers come back in no particular order.
        let mut reverse_v4 = answers(&hosts, "111.1.0.10.in-addr.arpa.", RecordType::PTR);
        reverse_v4.sort_by_key(|rdata| rdata.as_ptr().map(|ptr| ptr.0.clone()));
        assert_eq!(
            reverse_v4,
            vec![
                RData::PTR(PTR("a.example.com.".parse().unwrap())),
                RData::PTR(PTR("b.example.com.".parse().unwrap())),
            ]
        );

        let reverse_v6 = answers(
            &hosts,
            "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
            RecordType::PTR,
        );
        assert_eq!(
            reverse_v6,
            vec![RData::PTR(PTR("localhost.".parse().unwrap()))]
        );
    }
}