hickory_resolver/name_server/name_server_stats.rs
1// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::cmp::Ordering;
9use std::sync::{
10 atomic::{self, AtomicU32},
11 Arc,
12};
13
14use parking_lot::Mutex;
15use rand::Rng as _;
16
17#[cfg(not(test))]
18use std::time::{Duration, Instant};
19#[cfg(test)]
20use tokio::time::{Duration, Instant};
21
22pub(crate) struct NameServerStats {
23 /// The smoothed round-trip time (SRTT).
24 ///
25 /// This value represents an exponentially weighted moving average (EWMA) of
26 /// recorded latencies. The algorithm for computing this value is based on
27 /// the following:
28 ///
29 /// https://en.wikipedia.org/wiki/Moving_average#Application_to_measuring_computer_performance
30 ///
31 /// It is also partially inspired by the BIND and PowerDNS implementations:
32 ///
33 /// - https://github.com/isc-projects/bind9/blob/7bf8a7ab1b280c1021bf1e762a239b07aac3c591/lib/dns/adb.c#L3487
34 /// - https://github.com/PowerDNS/pdns/blob/7c5f9ae6ae4fb17302d933eaeebc8d6f0249aab2/pdns/syncres.cc#L123
35 ///
36 /// The algorithm for computing and using this value can be summarized as
37 /// follows:
38 ///
39 /// 1. The value is initialized to a random value that represents a very low
40 /// latency.
41 /// 2. If the round-trip time (RTT) was successfully measured for a query,
42 /// then it is incorporated into the EWMA using the formula linked above.
43 /// 3. If the RTT could not be measured (i.e. due to a connection failure),
44 /// then a constant penalty factor is applied to the EWMA.
45 /// 4. When comparing EWMA values, a time-based decay is applied to each
46 /// value. Note that this decay is only applied at read time.
47 ///
48 /// For the original discussion regarding this algorithm, see
49 /// https://github.com/hickory-dns/hickory-dns/issues/1702.
50 srtt_microseconds: AtomicU32,
51
52 /// The last time the `srtt_microseconds` value was updated.
53 last_update: Arc<Mutex<Option<Instant>>>,
54}
55
56impl Default for NameServerStats {
57 fn default() -> Self {
58 // Initialize the SRTT to a randomly generated value that represents a
59 // very low RTT. Such a value helps ensure that each server is attempted
60 // early.
61 Self::new(Duration::from_micros(rand::thread_rng().gen_range(1..32)))
62 }
63}
64
/// Returns an exponential decay factor between 0.0 and 1.0.
///
/// The factor is computed as:
///
/// e<sup>-(t<sub>now</sub> - t<sub>last</sub>) / weight</sup>
///
/// where the elapsed time since `last_update` is measured in seconds and is
/// treated as being at least one second. The more time has passed relative to
/// `weight`, the smaller the returned factor.
fn compute_srtt_factor(last_update: Instant, weight: u32) -> f64 {
    // Clamp to a minimum of one second so that back-to-back calls still
    // produce a factor strictly below 1.0.
    let elapsed_secs = last_update.elapsed().as_secs_f64().max(1.0);
    (-elapsed_secs / f64::from(weight)).exp()
}
77
78impl NameServerStats {
79 const CONNECTION_FAILURE_PENALTY: u32 = Duration::from_millis(150).as_micros() as u32;
80 const MAX_SRTT_MICROS: u32 = Duration::from_secs(5).as_micros() as u32;
81
82 pub(crate) fn new(initial_srtt: Duration) -> Self {
83 Self {
84 srtt_microseconds: AtomicU32::new(initial_srtt.as_micros() as u32),
85 last_update: Arc::new(Mutex::new(None)),
86 }
87 }
88
89 /// Records the measured `rtt` for a particular query.
90 pub(crate) fn record_rtt(&self, rtt: Duration) {
91 // If the cast on the result does overflow (it shouldn't), then the
92 // value is saturated to u32::MAX, which is above the `MAX_SRTT_MICROS`
93 // limit (meaning that any potential overflow is inconsequential).
94 // See https://github.com/rust-lang/rust/issues/10184.
95 self.update_srtt(
96 rtt.as_micros() as u32,
97 |cur_srtt_microseconds, last_update| {
98 // An arbitrarily low weight is used when computing the factor
99 // to ensure that recent RTT measurements are weighted more
100 // heavily.
101 let factor = compute_srtt_factor(last_update, 3);
102 let new_srtt = (1.0 - factor) * (rtt.as_micros() as f64)
103 + factor * f64::from(cur_srtt_microseconds);
104 new_srtt.round() as u32
105 },
106 );
107 }
108
109 /// Records a connection failure for a particular query.
110 pub(crate) fn record_connection_failure(&self) {
111 self.update_srtt(
112 Self::CONNECTION_FAILURE_PENALTY,
113 |cur_srtt_microseconds, _last_update| {
114 cur_srtt_microseconds.saturating_add(Self::CONNECTION_FAILURE_PENALTY)
115 },
116 );
117 }
118
119 /// Returns the raw SRTT value.
120 ///
121 /// Prefer to use `decayed_srtt` when ordering name servers.
122 fn srtt(&self) -> Duration {
123 Duration::from_micros(u64::from(
124 self.srtt_microseconds.load(atomic::Ordering::Acquire),
125 ))
126 }
127
128 /// Returns the SRTT value after applying a time based decay.
129 ///
130 /// The decay exponentially decreases the SRTT value. The primary reasons
131 /// for applying a downwards decay are twofold:
132 ///
133 /// 1. It helps distribute query load.
134 /// 2. It helps detect positive network changes. For example, decreases in
135 /// latency or a server that has recovered from a failure.
136 fn decayed_srtt(&self) -> f64 {
137 let srtt = f64::from(self.srtt_microseconds.load(atomic::Ordering::Acquire));
138 self.last_update.lock().map_or(srtt, |last_update| {
139 // In general, if the time between queries is relatively short, then
140 // the server ordering algorithm will approximate a spike
141 // distribution where the servers with the lowest latencies are
142 // chosen much more frequently. Conversely, if the time between
143 // queries is relatively long, then the query distribution will be
144 // more uniform. A larger weight widens the window in which servers
145 // with historically lower latencies will be heavily preferred. On
146 // the other hand, a larger weight may also increase the time it
147 // takes to recover from a failure or to observe positive changes in
148 // latency.
149 srtt * compute_srtt_factor(last_update, 180)
150 })
151 }
152
153 /// Updates the SRTT value.
154 ///
155 /// If the `last_update` value has not been set, then uses the `default`
156 /// value to update the SRTT. Otherwise, invokes the `update_fn` with the
157 /// current SRTT value and the `last_update` timestamp.
158 fn update_srtt(&self, default: u32, update_fn: impl Fn(u32, Instant) -> u32) {
159 let last_update = self.last_update.lock().replace(Instant::now());
160 let _ = self.srtt_microseconds.fetch_update(
161 atomic::Ordering::SeqCst,
162 atomic::Ordering::SeqCst,
163 move |cur_srtt_microseconds| {
164 Some(
165 last_update
166 .map_or(default, |last_update| {
167 update_fn(cur_srtt_microseconds, last_update)
168 })
169 .min(Self::MAX_SRTT_MICROS),
170 )
171 },
172 );
173 }
174}
175
176impl PartialEq for NameServerStats {
177 fn eq(&self, other: &Self) -> bool {
178 self.srtt() == other.srtt()
179 }
180}
181
182impl Eq for NameServerStats {}
183
// TODO: Replace this with `f64::total_cmp` once the Rust version is bumped to
// 1.62.0 (the method is stable beyond that version). In the meantime, the
// implementation is copied from here:
// https://github.com/rust-lang/rust/blob/master/library/core/src/num/f64.rs#L1336
/// Compares two `f64` values using a total order over their bit patterns.
///
/// Unlike `partial_cmp`, every pair of inputs has a defined ordering: `-0.0`
/// sorts before `+0.0`, and NaN values are placed according to their sign bit
/// and payload.
fn total_cmp(x: f64, y: f64) -> Ordering {
    /// Maps a float to a signed integer whose natural order matches the total
    /// order of the floats.
    fn sort_key(value: f64) -> i64 {
        let bits = value.to_bits() as i64;
        // The arithmetic shift turns the sign bit into an all-ones or
        // all-zeros mask; dropping the top bit of that mask leaves the sign
        // bit alone while flipping every other bit of negative values. This
        // reverses the order of negatives, which sign-magnitude encoding
        // would otherwise sort backwards.
        bits ^ ((((bits >> 63) as u64) >> 1) as i64)
    }

    sort_key(x).cmp(&sort_key(y))
}
197
198impl Ord for NameServerStats {
199 /// Custom implementation of Ord for NameServer which incorporates the
200 /// performance of the connection into it's ranking.
201 fn cmp(&self, other: &Self) -> Ordering {
202 total_cmp(self.decayed_srtt(), other.decayed_srtt())
203 }
204}
205
206impl PartialOrd for NameServerStats {
207 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
208 Some(self.cmp(other))
209 }
210}
211
#[cfg(test)]
// `is_send_sync` takes a type parameter purely for its trait bounds, which
// clippy would otherwise report as unused.
#[allow(clippy::extra_unused_type_parameters)]
mod tests {
    use super::*;

    /// Compiles only when `S` is both `Send` and `Sync`; always returns `true`.
    fn is_send_sync<S: Sync + Send>() -> bool {
        true
    }

    /// Moves the paused tokio clock forward by `secs` seconds.
    async fn advance_secs(secs: u64) {
        tokio::time::advance(Duration::from_secs(secs)).await;
    }

    /// The SRTT cap expressed as a `Duration`.
    fn max_srtt() -> Duration {
        Duration::from_micros(NameServerStats::MAX_SRTT_MICROS.into())
    }

    #[test]
    fn stats_are_sync() {
        assert!(is_send_sync::<NameServerStats>());
    }

    #[tokio::test(start_paused = true)]
    async fn test_stats_cmp() {
        let server_a = NameServerStats::new(Duration::from_micros(10));
        let server_b = NameServerStats::new(Duration::from_micros(20));

        // Nothing has been recorded yet, so the comparison falls back to the
        // initial SRTTs.
        assert_eq!(server_a.cmp(&server_b), Ordering::Less);

        // Once server A has been used, the still unused server B ranks first.
        server_a.record_rtt(Duration::from_millis(30));
        advance_secs(5).await;
        assert_eq!(server_a.cmp(&server_b), Ordering::Greater);

        // With both servers used, server A has the lower SRTT and ranks
        // first again.
        server_b.record_rtt(Duration::from_millis(50));
        advance_secs(5).await;
        assert_eq!(server_a.cmp(&server_b), Ordering::Less);

        // A connection failure on server A pushes it behind server B.
        server_a.record_connection_failure();
        advance_secs(5).await;
        assert_eq!(server_a.cmp(&server_b), Ordering::Greater);

        // Keep using server B until server A has decayed enough to rank
        // first once more.
        loop {
            if server_a.cmp(&server_b) == Ordering::Less {
                break;
            }
            server_b.record_rtt(Duration::from_millis(50));
            advance_secs(5).await;
        }

        server_a.record_rtt(Duration::from_millis(30));
        advance_secs(3).await;
        assert_eq!(server_a.cmp(&server_b), Ordering::Less);
    }

    #[tokio::test(start_paused = true)]
    async fn test_record_rtt() {
        let server = NameServerStats::new(Duration::from_micros(10));

        // The very first measurement overwrites the initial value.
        let first_rtt = Duration::from_millis(50);
        server.record_rtt(first_rtt);
        assert_eq!(server.srtt(), first_rtt);

        advance_secs(3).await;

        // Later measurements are blended with what was recorded before.
        server.record_rtt(Duration::from_millis(100));
        let expected = Duration::from_micros(81606);
        assert_eq!(server.srtt(), expected);
    }

    #[test]
    fn test_record_rtt_maximum_value() {
        let server = NameServerStats::new(Duration::from_micros(10));

        // An absurdly large RTT must be clamped to the cap.
        server.record_rtt(Duration::MAX);
        assert_eq!(server.srtt(), max_srtt());
    }

    #[tokio::test(start_paused = true)]
    async fn test_record_connection_failure() {
        let server = NameServerStats::new(Duration::from_micros(10));

        // The first failure replaces the initial SRTT with the penalty; each
        // further failure adds one more penalty on top.
        for failure_count in 1u32..=3 {
            server.record_connection_failure();

            let expected_micros = NameServerStats::CONNECTION_FAILURE_PENALTY
                .checked_mul(failure_count)
                .expect("checked_mul overflow");
            assert_eq!(
                server.srtt(),
                Duration::from_micros(u64::from(expected_micros))
            );

            advance_secs(3).await;
        }

        // A failure also refreshes the `last_update` timestamp, which the
        // next RTT calculation then relies on.
        server.record_rtt(Duration::from_millis(50));
        assert_eq!(server.srtt(), Duration::from_micros(197152));
    }

    #[test]
    fn test_record_connection_failure_maximum_value() {
        let server = NameServerStats::new(Duration::from_micros(10));

        // One failure more than is needed to reach the cap.
        let num_failures =
            NameServerStats::MAX_SRTT_MICROS / NameServerStats::CONNECTION_FAILURE_PENALTY + 1;
        (0..num_failures).for_each(|_| server.record_connection_failure());

        // Accumulated penalties must be clamped to the cap.
        assert_eq!(server.srtt(), max_srtt());
    }

    #[tokio::test(start_paused = true)]
    async fn test_decayed_srtt() {
        let initial_srtt: u64 = 10;
        let server = NameServerStats::new(Duration::from_micros(initial_srtt));

        // Without any recorded update there is nothing to decay.
        assert_eq!(server.decayed_srtt() as u32, initial_srtt as u32);

        advance_secs(5).await;
        server.record_rtt(Duration::from_millis(100));

        // Less than a second has passed since the update, but the decay is
        // computed as if a full second had elapsed.
        tokio::time::advance(Duration::from_millis(500)).await;
        assert_eq!(server.decayed_srtt() as u32, 99445);

        advance_secs(5).await;
        assert_eq!(server.decayed_srtt() as u32, 96990);
    }
}