1use std::collections::HashMap;
4use std::fs::File;
5use std::io;
6use std::net::IpAddr;
7use std::path::Path;
8use std::str::FromStr;
9use std::sync::Arc;
10
11use crate::proto::op::Query;
12use crate::proto::rr::rdata::PTR;
13use crate::proto::rr::{Name, RecordType};
14use crate::proto::rr::{RData, Record};
15use tracing::warn;
16
17use crate::dns_lru;
18use crate::lookup::Lookup;
19
/// Static lookups recorded for a single host name, kept separately per address family.
#[derive(Debug, Default)]
struct LookupType {
    /// IPv4 entries for the name; returned for `A` queries.
    a: Option<Lookup>,
    /// IPv6 entries for the name; returned for `AAAA` queries.
    aaaa: Option<Lookup>,
}
27
/// Static name-to-address mappings, such as those read from the system hosts file.
#[derive(Debug, Default)]
pub struct Hosts {
    /// Fully-qualified host names mapped to their `A`/`AAAA` lookups.
    by_name: HashMap<Name, LookupType>,
}
34
35impl Hosts {
36 #[cfg(any(unix, windows))]
40 pub fn from_system() -> io::Result<Self> {
41 Self::from_file(hosts_path())
42 }
43
44 #[cfg(not(any(unix, windows)))]
46 pub fn from_system() -> io::Result<Self> {
47 Ok(Hosts::default())
48 }
49
50 #[cfg(any(unix, windows))]
52 pub(crate) fn from_file(path: impl AsRef<Path>) -> io::Result<Self> {
53 let file = File::open(path)?;
54 let mut hosts = Self::default();
55 hosts.read_hosts_conf(file)?;
56 Ok(hosts)
57 }
58
59 pub fn lookup_static_host(&self, query: &Query) -> Option<Lookup> {
61 if self.by_name.is_empty() {
62 return None;
63 }
64
65 let mut name = query.name().clone();
66 name.set_fqdn(true);
67 match query.query_type() {
68 RecordType::A | RecordType::AAAA => {
69 let val = self.by_name.get(&name)?;
70
71 match query.query_type() {
72 RecordType::A => val.a.clone(),
73 RecordType::AAAA => val.aaaa.clone(),
74 _ => None,
75 }
76 }
77 RecordType::PTR => {
78 let ip = name.parse_arpa_name().ok()?;
79
80 let ip_addr = ip.addr();
81 let records = self
82 .by_name
83 .iter()
84 .filter(|(_, v)| match ip_addr {
85 IpAddr::V4(ip) => match v.a.as_ref() {
86 Some(lookup) => lookup
87 .iter()
88 .any(|r| r.ip_addr().map(|it| it == ip).unwrap_or_default()),
89 None => false,
90 },
91 IpAddr::V6(ip) => match v.aaaa.as_ref() {
92 Some(lookup) => lookup
93 .iter()
94 .any(|r| r.ip_addr().map(|it| it == ip).unwrap_or_default()),
95 None => false,
96 },
97 })
98 .map(|(n, _)| {
99 Record::from_rdata(
100 name.clone(),
101 dns_lru::MAX_TTL,
102 RData::PTR(PTR(n.clone())),
103 )
104 })
105 .collect::<Arc<[Record]>>();
106
107 if records.is_empty() {
108 return None;
109 }
110
111 Some(Lookup::new_with_max_ttl(query.clone(), records))
112 }
113 _ => None,
114 }
115 }
116
117 pub fn insert(&mut self, mut name: Name, record_type: RecordType, lookup: Lookup) {
119 assert!(record_type == RecordType::A || record_type == RecordType::AAAA);
120
121 name.set_fqdn(true);
122 let lookup_type = self.by_name.entry(name.clone()).or_default();
123
124 let new_lookup = {
125 let old_lookup = match record_type {
126 RecordType::A => lookup_type.a.get_or_insert_with(|| {
127 let query = Query::query(name.clone(), record_type);
128 Lookup::new_with_max_ttl(query, Arc::from([]))
129 }),
130 RecordType::AAAA => lookup_type.aaaa.get_or_insert_with(|| {
131 let query = Query::query(name.clone(), record_type);
132 Lookup::new_with_max_ttl(query, Arc::from([]))
133 }),
134 _ => {
135 tracing::warn!("unsupported IP type from Hosts file: {:#?}", record_type);
136 return;
137 }
138 };
139
140 old_lookup.append(lookup)
141 };
142
143 match record_type {
145 RecordType::A => lookup_type.a = Some(new_lookup),
146 RecordType::AAAA => lookup_type.aaaa = Some(new_lookup),
147 _ => tracing::warn!("unsupported IP type from Hosts file"),
148 }
149 }
150
151 pub fn read_hosts_conf(&mut self, src: impl io::Read) -> io::Result<()> {
153 use std::io::{BufRead, BufReader};
154
155 for (line_index, line) in BufReader::new(src).lines().enumerate() {
162 let line = line?;
163
164 let line = if line_index == 0 && line.starts_with('\u{feff}') {
166 &line[3..]
168 } else {
169 &line
170 };
171
172 let line = match line.split_once('#') {
174 Some((line, _)) => line,
175 None => line,
176 }
177 .trim();
178
179 if line.is_empty() {
180 continue;
181 }
182
183 let mut iter = line.split_whitespace();
184 let addr = match iter.next() {
185 Some(addr) => match IpAddr::from_str(addr) {
186 Ok(addr) => RData::from(addr),
187 Err(_) => {
188 warn!("could not parse an IP from hosts file ({addr:?})");
189 continue;
190 }
191 },
192 None => continue,
193 };
194
195 for domain in iter {
196 let domain = domain.to_lowercase();
197 let Ok(mut name) = Name::from_str(&domain) else {
198 continue;
199 };
200
201 name.set_fqdn(true);
202 let record = Record::from_rdata(name.clone(), dns_lru::MAX_TTL, addr.clone());
203 match addr {
204 RData::A(..) => {
205 let query = Query::query(name.clone(), RecordType::A);
206 let lookup = Lookup::new_with_max_ttl(query, Arc::from([record]));
207 self.insert(name.clone(), RecordType::A, lookup);
208 }
209 RData::AAAA(..) => {
210 let query = Query::query(name.clone(), RecordType::AAAA);
211 let lookup = Lookup::new_with_max_ttl(query, Arc::from([record]));
212 self.insert(name.clone(), RecordType::AAAA, lookup);
213 }
214 _ => {
215 warn!("unsupported IP type from Hosts file: {:#?}", addr);
216 continue;
217 }
218 };
219
220 }
222 }
223
224 Ok(())
225 }
226}
227
/// Location of the hosts file on Unix-like systems.
#[cfg(unix)]
fn hosts_path() -> &'static str {
    const UNIX_HOSTS_FILE: &str = "/etc/hosts";
    UNIX_HOSTS_FILE
}
232
/// Location of the hosts file on Windows: `%SystemRoot%\System32\drivers\etc\hosts`.
///
/// If `SystemRoot` is unset or empty, the default install root `C:\Windows` is used
/// instead of panicking. A file that still cannot be found is then reported as an
/// I/O error by the caller rather than aborting the process.
#[cfg(windows)]
fn hosts_path() -> std::path::PathBuf {
    let system_root = std::env::var_os("SystemRoot")
        .filter(|root| !root.is_empty())
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|| std::path::PathBuf::from("C:\\Windows"));
    system_root.join("System32\\drivers\\etc\\hosts")
}
240
#[cfg(any(unix, windows))]
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Directory holding the resolver test fixtures; the workspace root can be
    /// overridden with `TDNS_WORKSPACE_ROOT`.
    fn tests_dir() -> String {
        let workspace_root =
            env::var("TDNS_WORKSPACE_ROOT").unwrap_or_else(|_| String::from("../.."));
        format!("{workspace_root}/crates/resolver/tests")
    }

    /// Looks up `name`/`record_type` in `hosts`, panicking if there is no answer,
    /// and returns the answer's record data.
    fn resolve(hosts: &Hosts, name: &str, record_type: RecordType) -> Vec<RData> {
        let name = Name::from_str(name).unwrap();
        hosts
            .lookup_static_host(&Query::query(name, record_type))
            .unwrap()
            .iter()
            .cloned()
            .collect()
    }

    #[test]
    fn test_read_hosts_conf() {
        let hosts = Hosts::from_file(format!("{}/hosts", tests_dir())).unwrap();

        assert_eq!(
            resolve(&hosts, "localhost.", RecordType::A),
            vec![RData::A(Ipv4Addr::LOCALHOST.into())]
        );
        assert_eq!(
            resolve(&hosts, "localhost.", RecordType::AAAA),
            vec![RData::AAAA(Ipv6Addr::LOCALHOST.into())]
        );
        assert_eq!(
            resolve(&hosts, "broadcasthost", RecordType::A),
            vec![RData::A(Ipv4Addr::BROADCAST.into())]
        );

        let expected = [
            ("example.com", Ipv4Addr::new(10, 0, 1, 102)),
            ("a.example.com", Ipv4Addr::new(10, 0, 1, 111)),
            ("b.example.com", Ipv4Addr::new(10, 0, 1, 111)),
        ];
        for (host, addr) in expected {
            assert_eq!(
                resolve(&hosts, host, RecordType::A),
                vec![RData::A(addr.into())]
            );
        }

        // Several names share 10.0.1.111, and the answer order is not part of the
        // contract being checked here.
        let mut ptrs = resolve(&hosts, "111.1.0.10.in-addr.arpa.", RecordType::PTR);
        ptrs.sort_by_key(|r| r.as_ptr().as_ref().map(|p| p.0.clone()));
        assert_eq!(
            ptrs,
            vec![
                RData::PTR(PTR("a.example.com.".parse().unwrap())),
                RData::PTR(PTR("b.example.com.".parse().unwrap())),
            ]
        );

        assert_eq!(
            resolve(
                &hosts,
                "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.",
                RecordType::PTR,
            ),
            vec![RData::PTR(PTR("localhost.".parse().unwrap()))]
        );
    }
}