1use std::{convert::TryInto, io::Write};
2
3use crate::{rdata::OPT, ResourceRecord};
4
5use super::{PacketFlag, OPCODE, RCODE};
6
/// Bit masks for the fields packed into the 16-bit flags word of a DNS header
/// (bytes 2..4 of the packet).
pub(crate) mod masks {
    /// Four opcode bits occupying bits 11..=14 of the flags word.
    pub const OPCODE_MASK: u16 = 0xF << 11;
    /// Single reserved bit (bit 6); a parsed header with this bit set is rejected.
    pub const RESERVED_MASK: u16 = 1 << 6;
    /// Low four bits of the response code (bits 0..=3 of the flags word).
    pub const RESPONSE_CODE_MASK: u16 = 0xF;
}
12#[derive(Debug, Clone)]
14pub(crate) struct Header<'a> {
15 pub id: u16,
17 pub opcode: OPCODE,
19 pub response_code: RCODE,
21
22 pub z_flags: PacketFlag,
23
24 pub opt: Option<OPT<'a>>,
25}
26
27impl<'a> Header<'a> {
28 pub fn new_query(id: u16) -> Self {
30 Self {
31 id,
32 opcode: OPCODE::StandardQuery,
33 response_code: RCODE::NoError,
34 z_flags: PacketFlag::empty(),
35 opt: None,
36 }
37 }
38
39 pub fn new_reply(id: u16, opcode: OPCODE) -> Self {
41 Self {
42 id,
43 opcode,
44 response_code: RCODE::NoError,
45 z_flags: PacketFlag::RESPONSE,
46 opt: None,
47 }
48 }
49
50 pub fn set_flags(&mut self, flags: PacketFlag) {
51 self.z_flags |= flags;
52 }
53
54 pub fn remove_flags(&mut self, flags: PacketFlag) {
55 self.z_flags.remove(flags);
56 }
57
58 pub fn has_flags(&self, flags: PacketFlag) -> bool {
59 self.z_flags.contains(flags)
60 }
61
62 pub fn parse(data: &[u8]) -> crate::Result<Self> {
64 if data.len() < 12 {
65 return Err(crate::SimpleDnsError::InsufficientData);
66 }
67
68 let flags = u16::from_be_bytes(data[2..4].try_into()?);
69 if flags & masks::RESERVED_MASK != 0 {
70 return Err(crate::SimpleDnsError::InvalidHeaderData);
71 }
72
73 let header = Self {
74 id: u16::from_be_bytes(data[..2].try_into()?),
75 opcode: ((flags & masks::OPCODE_MASK) >> masks::OPCODE_MASK.trailing_zeros()).into(),
76 response_code: (flags & masks::RESPONSE_CODE_MASK).into(),
77 z_flags: PacketFlag::from_bits_truncate(flags),
78 opt: None,
79 };
80 Ok(header)
81 }
82
83 pub fn write_to<T: Write>(
85 &self,
86 buffer: &mut T,
87 questions: u16,
88 answers: u16,
89 name_servers: u16,
90 additional_records: u16,
91 ) -> crate::Result<()> {
92 buffer.write_all(&self.id.to_be_bytes())?;
93 buffer.write_all(&self.get_flags().to_be_bytes())?;
94 buffer.write_all(&questions.to_be_bytes())?;
95 buffer.write_all(&answers.to_be_bytes())?;
96 buffer.write_all(&name_servers.to_be_bytes())?;
97 buffer.write_all(&additional_records.to_be_bytes())?;
98
99 Ok(())
100 }
101
102 fn get_flags(&self) -> u16 {
103 let mut flags = self.z_flags.bits();
104
105 flags |= (self.opcode as u16) << masks::OPCODE_MASK.trailing_zeros();
106 flags |= self.response_code as u16 & masks::RESPONSE_CODE_MASK;
107
108 flags
109 }
110
111 pub(crate) fn opt_rr(&self) -> Option<ResourceRecord> {
112 self.opt.as_ref().map(|opt| {
113 ResourceRecord::new(
114 crate::Name::new_unchecked("."),
115 crate::CLASS::IN,
116 opt.encode_ttl(self),
117 crate::rdata::RData::OPT(opt.clone()),
118 )
119 })
120 }
121
122 pub(crate) fn extract_info_from_opt_rr(&mut self, opt_rr: Option<ResourceRecord<'a>>) {
123 if let Some(opt) = opt_rr {
124 self.response_code = OPT::extract_rcode_from_ttl(opt.ttl, self);
125 self.opt = match opt.rdata {
126 crate::rdata::RData::OPT(opt) => Some(opt),
127 _ => unreachable!(),
128 };
129 }
130 }
131}
132
#[cfg(test)]
mod tests {
    use crate::header_buffer;

    use super::*;

    #[test]
    fn write_example_query() {
        let mut header = Header::new_query(u16::MAX);

        header.set_flags(PacketFlag::TRUNCATION | PacketFlag::RECURSION_DESIRED);

        let mut buf = vec![];
        header.write_to(&mut buf, 0, 0, 0, 0).unwrap();

        assert_eq!(
            b"\xff\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00",
            &buf[..]
        );
    }

    #[test]
    fn parse_example_query() {
        let buffer = b"\xff\xff\x03\x00\x00\x02\x00\x02\x00\x02\x00\x02";
        let header = Header::parse(&buffer[..]).unwrap();

        assert_eq!(u16::MAX, header.id);
        assert_eq!(OPCODE::StandardQuery, header.opcode);
        // `has_flags` is an "all of" check, so each absent flag is asserted separately.
        // Negating a combined check would pass as soon as any one flag was absent.
        assert!(!header.has_flags(PacketFlag::AUTHORITATIVE_ANSWER));
        assert!(!header.has_flags(PacketFlag::RECURSION_AVAILABLE));
        assert!(!header.has_flags(PacketFlag::RESPONSE));
        assert!(header.has_flags(PacketFlag::TRUNCATION | PacketFlag::RECURSION_DESIRED));
        assert_eq!(RCODE::NoError, header.response_code);
        assert_eq!(2, header_buffer::additional_records(&buffer[..]).unwrap());
        assert_eq!(2, header_buffer::answers(&buffer[..]).unwrap());
        assert_eq!(2, header_buffer::name_servers(&buffer[..]).unwrap());
        assert_eq!(2, header_buffer::questions(&buffer[..]).unwrap());
    }

    #[test]
    fn parse_rejects_short_buffer() {
        let result = Header::parse(&[0u8; 11]);
        assert!(matches!(
            result,
            Err(crate::SimpleDnsError::InsufficientData)
        ));
    }

    #[test]
    fn parse_rejects_reserved_bit() {
        // The flags word 0x0040 has only the reserved bit set.
        let buffer = b"\x00\x01\x00\x40\x00\x00\x00\x00\x00\x00\x00\x00";
        let result = Header::parse(&buffer[..]);
        assert!(matches!(
            result,
            Err(crate::SimpleDnsError::InvalidHeaderData)
        ));
    }

    #[test]
    fn remove_flags_clears_only_given_flags() {
        let mut header = Header::new_query(1);
        header.set_flags(PacketFlag::TRUNCATION | PacketFlag::RECURSION_DESIRED);

        header.remove_flags(PacketFlag::TRUNCATION);

        assert!(!header.has_flags(PacketFlag::TRUNCATION));
        assert!(header.has_flags(PacketFlag::RECURSION_DESIRED));
    }

    #[test]
    fn write_then_parse_roundtrip() {
        let mut header = Header::new_reply(0x1234, OPCODE::StandardQuery);
        header.set_flags(PacketFlag::AUTHORITATIVE_ANSWER | PacketFlag::RECURSION_AVAILABLE);

        let mut buffer = vec![];
        header.write_to(&mut buffer, 1, 2, 3, 4).unwrap();
        assert_eq!(12, buffer.len());

        let parsed = Header::parse(&buffer[..]).unwrap();
        assert_eq!(0x1234, parsed.id);
        assert_eq!(OPCODE::StandardQuery, parsed.opcode);
        assert_eq!(RCODE::NoError, parsed.response_code);
        assert!(parsed.has_flags(
            PacketFlag::RESPONSE
                | PacketFlag::AUTHORITATIVE_ANSWER
                | PacketFlag::RECURSION_AVAILABLE
        ));
        assert!(!parsed.has_flags(PacketFlag::TRUNCATION));
        assert_eq!(1, header_buffer::questions(&buffer[..]).unwrap());
        assert_eq!(2, header_buffer::answers(&buffer[..]).unwrap());
        assert_eq!(3, header_buffer::name_servers(&buffer[..]).unwrap());
        assert_eq!(4, header_buffer::additional_records(&buffer[..]).unwrap());
    }

    #[test]
    fn read_write_questions_count() {
        let mut buffer = [0u8; 12];
        header_buffer::set_questions(&mut buffer, 1);
        assert_eq!(1, header_buffer::questions(&buffer).unwrap());
    }

    #[test]
    fn read_write_answers_count() {
        let mut buffer = [0u8; 12];
        header_buffer::set_answers(&mut buffer, 1);
        assert_eq!(1, header_buffer::answers(&buffer).unwrap());
    }

    #[test]
    fn read_write_name_servers_count() {
        let mut buffer = [0u8; 12];
        header_buffer::set_name_servers(&mut buffer, 1);
        assert_eq!(1, header_buffer::name_servers(&buffer).unwrap());
    }

    #[test]
    fn read_write_additional_records_count() {
        let mut buffer = [0u8; 12];
        header_buffer::set_additional_records(&mut buffer, 1);
        assert_eq!(1, header_buffer::additional_records(&buffer).unwrap());
    }

    #[test]
    fn big_rcode_doesnt_break_header() {
        let mut header = Header::new_reply(1, OPCODE::StandardQuery);
        header.response_code = RCODE::BADVERS;

        let mut buffer = vec![];
        header.write_to(&mut buffer, 0, 0, 0, 0).unwrap();

        assert_ne!(RCODE::BADVERS, header_buffer::rcode(&buffer[..]).unwrap());

        let header = Header::parse(&buffer[..]).expect("Header parsing failed");
        assert_eq!(RCODE::NoError, header.response_code);
        assert!(header.has_flags(PacketFlag::RESPONSE));
    }
}