hickory_resolver/name_server/
name_server_stats.rs

1// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::cmp::Ordering;
9use std::sync::{
10    atomic::{self, AtomicU32},
11    Arc,
12};
13
14use parking_lot::Mutex;
15use rand::Rng as _;
16
17#[cfg(not(test))]
18use std::time::{Duration, Instant};
19#[cfg(test)]
20use tokio::time::{Duration, Instant};
21
22pub(crate) struct NameServerStats {
23    /// The smoothed round-trip time (SRTT).
24    ///
25    /// This value represents an exponentially weighted moving average (EWMA) of
26    /// recorded latencies. The algorithm for computing this value is based on
27    /// the following:
28    ///
29    /// https://en.wikipedia.org/wiki/Moving_average#Application_to_measuring_computer_performance
30    ///
31    /// It is also partially inspired by the BIND and PowerDNS implementations:
32    ///
33    /// - https://github.com/isc-projects/bind9/blob/7bf8a7ab1b280c1021bf1e762a239b07aac3c591/lib/dns/adb.c#L3487
34    /// - https://github.com/PowerDNS/pdns/blob/7c5f9ae6ae4fb17302d933eaeebc8d6f0249aab2/pdns/syncres.cc#L123
35    ///
36    /// The algorithm for computing and using this value can be summarized as
37    /// follows:
38    ///
39    /// 1. The value is initialized to a random value that represents a very low
40    ///    latency.
41    /// 2. If the round-trip time (RTT) was successfully measured for a query,
42    ///    then it is incorporated into the EWMA using the formula linked above.
43    /// 3. If the RTT could not be measured (i.e. due to a connection failure),
44    ///    then a constant penalty factor is applied to the EWMA.
45    /// 4. When comparing EWMA values, a time-based decay is applied to each
46    ///    value. Note that this decay is only applied at read time.
47    ///
48    /// For the original discussion regarding this algorithm, see
49    /// https://github.com/hickory-dns/hickory-dns/issues/1702.
50    srtt_microseconds: AtomicU32,
51
52    /// The last time the `srtt_microseconds` value was updated.
53    last_update: Arc<Mutex<Option<Instant>>>,
54}
55
56impl Default for NameServerStats {
57    fn default() -> Self {
58        // Initialize the SRTT to a randomly generated value that represents a
59        // very low RTT. Such a value helps ensure that each server is attempted
60        // early.
61        Self::new(Duration::from_micros(rand::thread_rng().gen_range(1..32)))
62    }
63}
64
/// Computes the exponential weighting factor for a value that was last
/// updated at `last_update`.
///
/// The factor is given by
///
/// e<sup>-max(t<sub>now</sub> - t<sub>last</sub>, 1) / weight</sup>
///
/// where the elapsed time is measured in seconds and is always treated as
/// being at least one second. The returned value therefore shrinks towards
/// 0.0 as the time since `last_update` grows relative to `weight`, and it
/// never exceeds e<sup>-1 / weight</sup>.
fn compute_srtt_factor(last_update: Instant, weight: u32) -> f64 {
    // Clamp the elapsed time so that very recent updates are still decayed
    // as if a full second had passed.
    let elapsed_secs = last_update.elapsed().as_secs_f64().max(1.0);
    let weight = f64::from(weight);
    (-elapsed_secs / weight).exp()
}
77
78impl NameServerStats {
79    const CONNECTION_FAILURE_PENALTY: u32 = Duration::from_millis(150).as_micros() as u32;
80    const MAX_SRTT_MICROS: u32 = Duration::from_secs(5).as_micros() as u32;
81
82    pub(crate) fn new(initial_srtt: Duration) -> Self {
83        Self {
84            srtt_microseconds: AtomicU32::new(initial_srtt.as_micros() as u32),
85            last_update: Arc::new(Mutex::new(None)),
86        }
87    }
88
89    /// Records the measured `rtt` for a particular query.
90    pub(crate) fn record_rtt(&self, rtt: Duration) {
91        // If the cast on the result does overflow (it shouldn't), then the
92        // value is saturated to u32::MAX, which is above the `MAX_SRTT_MICROS`
93        // limit (meaning that any potential overflow is inconsequential).
94        // See https://github.com/rust-lang/rust/issues/10184.
95        self.update_srtt(
96            rtt.as_micros() as u32,
97            |cur_srtt_microseconds, last_update| {
98                // An arbitrarily low weight is used when computing the factor
99                // to ensure that recent RTT measurements are weighted more
100                // heavily.
101                let factor = compute_srtt_factor(last_update, 3);
102                let new_srtt = (1.0 - factor) * (rtt.as_micros() as f64)
103                    + factor * f64::from(cur_srtt_microseconds);
104                new_srtt.round() as u32
105            },
106        );
107    }
108
109    /// Records a connection failure for a particular query.
110    pub(crate) fn record_connection_failure(&self) {
111        self.update_srtt(
112            Self::CONNECTION_FAILURE_PENALTY,
113            |cur_srtt_microseconds, _last_update| {
114                cur_srtt_microseconds.saturating_add(Self::CONNECTION_FAILURE_PENALTY)
115            },
116        );
117    }
118
119    /// Returns the raw SRTT value.
120    ///
121    /// Prefer to use `decayed_srtt` when ordering name servers.
122    fn srtt(&self) -> Duration {
123        Duration::from_micros(u64::from(
124            self.srtt_microseconds.load(atomic::Ordering::Acquire),
125        ))
126    }
127
128    /// Returns the SRTT value after applying a time based decay.
129    ///
130    /// The decay exponentially decreases the SRTT value. The primary reasons
131    /// for applying a downwards decay are twofold:
132    ///
133    /// 1. It helps distribute query load.
134    /// 2. It helps detect positive network changes. For example, decreases in
135    ///    latency or a server that has recovered from a failure.
136    fn decayed_srtt(&self) -> f64 {
137        let srtt = f64::from(self.srtt_microseconds.load(atomic::Ordering::Acquire));
138        self.last_update.lock().map_or(srtt, |last_update| {
139            // In general, if the time between queries is relatively short, then
140            // the server ordering algorithm will approximate a spike
141            // distribution where the servers with the lowest latencies are
142            // chosen much more frequently. Conversely, if the time between
143            // queries is relatively long, then the query distribution will be
144            // more uniform. A larger weight widens the window in which servers
145            // with historically lower latencies will be heavily preferred. On
146            // the other hand, a larger weight may also increase the time it
147            // takes to recover from a failure or to observe positive changes in
148            // latency.
149            srtt * compute_srtt_factor(last_update, 180)
150        })
151    }
152
153    /// Updates the SRTT value.
154    ///
155    /// If the `last_update` value has not been set, then uses the `default`
156    /// value to update the SRTT. Otherwise, invokes the `update_fn` with the
157    /// current SRTT value and the `last_update` timestamp.
158    fn update_srtt(&self, default: u32, update_fn: impl Fn(u32, Instant) -> u32) {
159        let last_update = self.last_update.lock().replace(Instant::now());
160        let _ = self.srtt_microseconds.fetch_update(
161            atomic::Ordering::SeqCst,
162            atomic::Ordering::SeqCst,
163            move |cur_srtt_microseconds| {
164                Some(
165                    last_update
166                        .map_or(default, |last_update| {
167                            update_fn(cur_srtt_microseconds, last_update)
168                        })
169                        .min(Self::MAX_SRTT_MICROS),
170                )
171            },
172        );
173    }
174}
175
176impl PartialEq for NameServerStats {
177    fn eq(&self, other: &Self) -> bool {
178        self.srtt() == other.srtt()
179    }
180}
181
182impl Eq for NameServerStats {}
183
// TODO: Replace this with `f64::total_cmp` once the Rust version is bumped to
// 1.62.0 (the method is stable beyond that version). In the meantime, the
// implementation is copied from here:
// https://github.com/rust-lang/rust/blob/master/library/core/src/num/f64.rs#L1336
/// Compares two `f64` values using a total order over their bit patterns.
///
/// Unlike `PartialOrd` for `f64`, every pair of values is comparable:
/// `-0.0` sorts before `+0.0`, and NaN values sort according to their sign
/// bit and payload instead of being unordered.
fn total_cmp(x: f64, y: f64) -> Ordering {
    /// Maps a float onto a signed integer whose natural order matches the
    /// total order of the float.
    fn sort_key(value: f64) -> i64 {
        let bits = value.to_bits() as i64;
        // The arithmetic shift yields all ones for negative values and zero
        // otherwise. Dropping the top bit gives a mask that flips every bit
        // except the sign bit, which reverses the order of negative values.
        let flip_mask = (((bits >> 63) as u64) >> 1) as i64;
        bits ^ flip_mask
    }

    sort_key(x).cmp(&sort_key(y))
}
197
198impl Ord for NameServerStats {
199    /// Custom implementation of Ord for NameServer which incorporates the
200    /// performance of the connection into it's ranking.
201    fn cmp(&self, other: &Self) -> Ordering {
202        total_cmp(self.decayed_srtt(), other.decayed_srtt())
203    }
204}
205
206impl PartialOrd for NameServerStats {
207    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
208        Some(self.cmp(other))
209    }
210}
211
#[cfg(test)]
#[allow(clippy::extra_unused_type_parameters)]
mod tests {
    use super::*;

    /// Only compiles when `T` is both `Send` and `Sync`.
    fn is_send_sync<T: Send + Sync>() -> bool {
        true
    }

    /// Moves tokio's paused test clock forward by `secs` seconds.
    async fn advance_secs(secs: u64) {
        tokio::time::advance(Duration::from_secs(secs)).await;
    }

    /// Converts a microsecond count into a `Duration`.
    fn micros(value: u32) -> Duration {
        Duration::from_micros(u64::from(value))
    }

    #[test]
    fn stats_are_sync() {
        assert!(is_send_sync::<NameServerStats>());
    }

    #[tokio::test(start_paused = true)]
    async fn test_stats_cmp() {
        let server_a = NameServerStats::new(micros(10));
        let server_b = NameServerStats::new(micros(20));
        let rtt_a = Duration::from_millis(30);
        let rtt_b = Duration::from_millis(50);

        // Nothing has been recorded yet, so the comparison falls back to the
        // initial SRTTs.
        assert_eq!(server_a.cmp(&server_b), Ordering::Less);

        // Once server A has been used, the still unused server B should be
        // preferred.
        server_a.record_rtt(rtt_a);
        advance_secs(5).await;
        assert_eq!(server_a.cmp(&server_b), Ordering::Greater);

        // With both servers used, server A has the lower SRTT and should be
        // preferred.
        server_b.record_rtt(rtt_b);
        advance_secs(5).await;
        assert_eq!(server_a.cmp(&server_b), Ordering::Less);

        // A connection failure on server A results in server B being
        // preferred.
        server_a.record_connection_failure();
        advance_secs(5).await;
        assert_eq!(server_a.cmp(&server_b), Ordering::Greater);

        // Server A should eventually recover and once again be preferred.
        while server_a.cmp(&server_b) != Ordering::Less {
            server_b.record_rtt(rtt_b);
            advance_secs(5).await;
        }

        server_a.record_rtt(rtt_a);
        advance_secs(3).await;
        assert_eq!(server_a.cmp(&server_b), Ordering::Less);
    }

    #[tokio::test(start_paused = true)]
    async fn test_record_rtt() {
        let server = NameServerStats::new(micros(10));

        // The first recorded RTT should replace the initial value.
        let first_rtt = Duration::from_millis(50);
        server.record_rtt(first_rtt);
        assert_eq!(server.srtt(), first_rtt);

        advance_secs(3).await;

        // Subsequent RTTs should factor in previously recorded values.
        let second_rtt = Duration::from_millis(100);
        server.record_rtt(second_rtt);
        assert_eq!(server.srtt(), micros(81606));
    }

    #[test]
    fn test_record_rtt_maximum_value() {
        let server = NameServerStats::new(micros(10));

        server.record_rtt(Duration::MAX);

        // Updates to the SRTT are capped at a maximum value.
        let capped = micros(NameServerStats::MAX_SRTT_MICROS);
        assert_eq!(server.srtt(), capped);
    }

    #[tokio::test(start_paused = true)]
    async fn test_record_connection_failure() {
        let server = NameServerStats::new(micros(10));

        // The first failure replaces the SRTT with the penalty; every
        // following failure adds the penalty on top.
        for failure_count in 1..=3u32 {
            server.record_connection_failure();

            let expected_micros = NameServerStats::CONNECTION_FAILURE_PENALTY
                .checked_mul(failure_count)
                .expect("checked_mul overflow");
            assert_eq!(server.srtt(), micros(expected_micros));

            advance_secs(3).await;
        }

        // A connection failure also refreshes the `last_update` timestamp,
        // which is then used in subsequent calculations.
        server.record_rtt(Duration::from_millis(50));
        assert_eq!(server.srtt(), micros(197152));
    }

    #[test]
    fn test_record_connection_failure_maximum_value() {
        let server = NameServerStats::new(micros(10));

        // One more failure than is needed to reach the cap.
        let failures_to_cap =
            NameServerStats::MAX_SRTT_MICROS / NameServerStats::CONNECTION_FAILURE_PENALTY;
        for _ in 0..(failures_to_cap + 1) {
            server.record_connection_failure();
        }

        // Updates to the SRTT are capped at a maximum value.
        let capped = micros(NameServerStats::MAX_SRTT_MICROS);
        assert_eq!(server.srtt(), capped);
    }

    #[tokio::test(start_paused = true)]
    async fn test_decayed_srtt() {
        let initial_micros: u32 = 10;
        let server = NameServerStats::new(micros(initial_micros));

        // No decay should be applied to the initial value.
        assert_eq!(server.decayed_srtt() as u32, initial_micros);

        advance_secs(5).await;
        server.record_rtt(Duration::from_millis(100));

        // The decay function should assume a minimum of one second has
        // elapsed since the last update.
        tokio::time::advance(Duration::from_millis(500)).await;
        assert_eq!(server.decayed_srtt() as u32, 99445);

        advance_secs(5).await;
        assert_eq!(server.decayed_srtt() as u32, 96990);
    }
}