Skip to main content

forge/cmd/
snapshot.rs

1use super::test;
2use crate::result::{SuiteTestResult, TestKindReport, TestOutcome};
3use alloy_primitives::{U256, map::HashMap};
4use clap::{Parser, ValueHint, builder::RangedU64ValueParser};
5use eyre::{Context, Result};
6use foundry_cli::utils::STATIC_FUZZ_SEED;
7use regex::Regex;
8use std::{
9    cmp::Ordering,
10    fs,
11    io::{self, BufRead},
12    path::{Path, PathBuf},
13    str::FromStr,
14    sync::LazyLock,
15};
16use yansi::Paint;
17
18/// A regex that matches a basic snapshot entry like
19/// `Test:testDeposit() (gas: 58804)`
20pub static RE_BASIC_SNAPSHOT_ENTRY: LazyLock<Regex> = LazyLock::new(|| {
21    Regex::new(r"(?P<file>(.*?)):(?P<sig>(\w+)\s*\((.*?)\))\s*\(((gas:)?\s*(?P<gas>\d+)|(runs:\s*(?P<runs>\d+),\s*μ:\s*(?P<avg>\d+),\s*~:\s*(?P<med>\d+))|(runs:\s*(?P<invruns>\d+),\s*calls:\s*(?P<calls>\d+),\s*reverts:\s*(?P<reverts>\d+)))\)").unwrap()
22});
23
/// CLI arguments for `forge snapshot`.
// NOTE: with clap's derive, the `///` doc comments in this struct are used as `--help`
// text, so maintainer notes are written as plain `//` comments instead.
#[derive(Clone, Debug, Parser)]
pub struct GasSnapshotArgs {
    /// Output a diff against a pre-existing gas snapshot.
    ///
    /// By default, the comparison is done with .gas-snapshot.
    // `Option<Option<_>>` (clap derive convention): outer `None` when `--diff` is absent,
    // inner `None` when the flag is given without a path, in which case `run` falls back
    // to `snap`.
    #[arg(
        conflicts_with = "snap",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    diff: Option<Option<PathBuf>>,

    /// Compare against a pre-existing gas snapshot, exiting with code 1 if they do not match.
    ///
    /// Outputs a diff if the gas snapshots do not match.
    ///
    /// By default, the comparison is done with .gas-snapshot.
    // Same `Option<Option<_>>` shape as `diff`; a bare `--check` also falls back to `snap`.
    #[arg(
        conflicts_with = "diff",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    check: Option<Option<PathBuf>>,

    // Hidden because there is only one option
    /// How to format the output.
    // Passed through to `write_to_gas_snapshot_file`, which currently ignores it.
    #[arg(long, hide(true))]
    format: Option<Format>,

    /// Output file for the gas snapshot.
    // Written to when neither `--diff` nor `--check` is given; otherwise it is the default
    // comparison file.
    #[arg(
        long,
        default_value = ".gas-snapshot",
        value_hint = ValueHint::FilePath,
        value_name = "FILE",
    )]
    snap: PathBuf,

    /// Tolerates gas deviations up to the specified percentage.
    // Accepted values are 0..100 (100 excluded). Only used by `--check`, where it is
    // forwarded to `check`.
    #[arg(
        long,
        value_parser = RangedU64ValueParser::<u32>::new().range(0..100),
        value_name = "SNAPSHOT_THRESHOLD"
    )]
    tolerance: Option<u32>,

    /// All test arguments are supported
    #[command(flatten)]
    pub(crate) test: test::TestArgs,

    /// Additional configs for test results
    #[command(flatten)]
    config: GasSnapshotConfig,
}
81
82impl GasSnapshotArgs {
83    /// Returns whether `GasSnapshotArgs` was configured with `--watch`
84    pub fn is_watch(&self) -> bool {
85        self.test.is_watch()
86    }
87
88    /// Returns the [`watchexec::Config`] necessary to bootstrap a new watch loop.
89    pub(crate) fn watchexec_config(&self) -> Result<watchexec::Config> {
90        self.test.watchexec_config()
91    }
92
93    pub async fn run(mut self) -> Result<()> {
94        // Set fuzz seed so gas snapshots are deterministic
95        self.test.fuzz_seed = Some(U256::from_be_bytes(STATIC_FUZZ_SEED));
96
97        let outcome = self.test.execute_tests().await?;
98        outcome.ensure_ok(false)?;
99        let tests = self.config.apply(outcome);
100
101        if let Some(path) = self.diff {
102            let snap = path.as_ref().unwrap_or(&self.snap);
103            let snaps = read_gas_snapshot(snap)?;
104            diff(tests, snaps)?;
105        } else if let Some(path) = self.check {
106            let snap = path.as_ref().unwrap_or(&self.snap);
107            let snaps = read_gas_snapshot(snap)?;
108            if check(tests, snaps, self.tolerance) {
109                std::process::exit(0)
110            } else {
111                std::process::exit(1)
112            }
113        } else {
114            write_to_gas_snapshot_file(&tests, self.snap, self.format)?;
115        }
116        Ok(())
117    }
118}
119
// TODO implement pretty tables
/// Output format selectable through the hidden `--format` flag.
#[derive(Clone, Debug)]
pub enum Format {
    Table,
}

impl FromStr for Format {
    type Err = String;

    /// Accepts `t` or `table`; any other input yields an error message naming it.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if matches!(s, "t" | "table") {
            Ok(Self::Table)
        } else {
            Err(format!("Unrecognized format `{s}`"))
        }
    }
}
136
137/// Additional filters that can be applied on the test results
138#[derive(Clone, Debug, Default, Parser)]
139struct GasSnapshotConfig {
140    /// Sort results by gas used (ascending).
141    #[arg(long)]
142    asc: bool,
143
144    /// Sort results by gas used (descending).
145    #[arg(conflicts_with = "asc", long)]
146    desc: bool,
147
148    /// Only include tests that used more gas that the given amount.
149    #[arg(long, value_name = "MIN_GAS")]
150    min: Option<u64>,
151
152    /// Only include tests that used less gas that the given amount.
153    #[arg(long, value_name = "MAX_GAS")]
154    max: Option<u64>,
155}
156
157impl GasSnapshotConfig {
158    fn is_in_gas_range(&self, gas_used: u64) -> bool {
159        if let Some(min) = self.min
160            && gas_used < min
161        {
162            return false;
163        }
164        if let Some(max) = self.max
165            && gas_used > max
166        {
167            return false;
168        }
169        true
170    }
171
172    fn apply(&self, outcome: TestOutcome) -> Vec<SuiteTestResult> {
173        let mut tests = outcome
174            .into_tests()
175            .filter(|test| self.is_in_gas_range(test.gas_used()))
176            .collect::<Vec<_>>();
177
178        if self.asc {
179            tests.sort_by_key(|a| a.gas_used());
180        } else if self.desc {
181            tests.sort_by_key(|b| std::cmp::Reverse(b.gas_used()))
182        }
183
184        tests
185    }
186}
187
188/// A general entry in a gas snapshot file
189///
190/// Has the form:
191///   `<signature>(gas:? 40181)` for normal tests
192///   `<signature>(runs: 256, μ: 40181, ~: 40181)` for fuzz tests
193///   `<signature>(runs: 256, calls: 40181, reverts: 40181)` for invariant tests
194#[derive(Clone, Debug, PartialEq, Eq)]
195pub struct GasSnapshotEntry {
196    pub contract_name: String,
197    pub signature: String,
198    pub gas_used: TestKindReport,
199}
200
201impl FromStr for GasSnapshotEntry {
202    type Err = String;
203
204    fn from_str(s: &str) -> Result<Self, Self::Err> {
205        RE_BASIC_SNAPSHOT_ENTRY
206            .captures(s)
207            .and_then(|cap| {
208                cap.name("file").and_then(|file| {
209                    cap.name("sig").and_then(|sig| {
210                        if let Some(gas) = cap.name("gas") {
211                            Some(Self {
212                                contract_name: file.as_str().to_string(),
213                                signature: sig.as_str().to_string(),
214                                gas_used: TestKindReport::Unit {
215                                    gas: gas.as_str().parse().unwrap(),
216                                },
217                            })
218                        } else if let Some(runs) = cap.name("runs") {
219                            cap.name("avg")
220                                .and_then(|avg| cap.name("med").map(|med| (runs, avg, med)))
221                                .map(|(runs, avg, med)| Self {
222                                    contract_name: file.as_str().to_string(),
223                                    signature: sig.as_str().to_string(),
224                                    gas_used: TestKindReport::Fuzz {
225                                        runs: runs.as_str().parse().unwrap(),
226                                        median_gas: med.as_str().parse().unwrap(),
227                                        mean_gas: avg.as_str().parse().unwrap(),
228                                    },
229                                })
230                        } else {
231                            cap.name("invruns")
232                                .and_then(|runs| {
233                                    cap.name("calls").and_then(|avg| {
234                                        cap.name("reverts").map(|med| (runs, avg, med))
235                                    })
236                                })
237                                .map(|(runs, calls, reverts)| Self {
238                                    contract_name: file.as_str().to_string(),
239                                    signature: sig.as_str().to_string(),
240                                    gas_used: TestKindReport::Invariant {
241                                        runs: runs.as_str().parse().unwrap(),
242                                        calls: calls.as_str().parse().unwrap(),
243                                        reverts: reverts.as_str().parse().unwrap(),
244                                        metrics: HashMap::default(),
245                                        failed_corpus_replays: 0,
246                                    },
247                                })
248                        }
249                    })
250                })
251            })
252            .ok_or_else(|| format!("Could not extract Snapshot Entry for {s}"))
253    }
254}
255
256/// Reads a list of gas snapshot entries from a gas snapshot file.
257fn read_gas_snapshot(path: impl AsRef<Path>) -> Result<Vec<GasSnapshotEntry>> {
258    let path = path.as_ref();
259    let mut entries = Vec::new();
260    for line in io::BufReader::new(
261        fs::File::open(path)
262            .wrap_err(format!("failed to read snapshot file \"{}\"", path.display()))?,
263    )
264    .lines()
265    {
266        entries
267            .push(GasSnapshotEntry::from_str(line?.as_str()).map_err(|err| eyre::eyre!("{err}"))?);
268    }
269    Ok(entries)
270}
271
272/// Writes a series of tests to a gas snapshot file after sorting them.
273fn write_to_gas_snapshot_file(
274    tests: &[SuiteTestResult],
275    path: impl AsRef<Path>,
276    _format: Option<Format>,
277) -> Result<()> {
278    let mut reports = tests
279        .iter()
280        .map(|test| {
281            format!("{}:{} {}", test.contract_name(), test.signature, test.result.kind.report())
282        })
283        .collect::<Vec<_>>();
284
285    // sort all reports
286    reports.sort();
287
288    let content = reports.join("\n");
289    Ok(fs::write(path, content)?)
290}
291
292/// A Gas snapshot entry diff.
293#[derive(Clone, Debug, PartialEq, Eq)]
294pub struct GasSnapshotDiff {
295    pub signature: String,
296    pub source_gas_used: TestKindReport,
297    pub target_gas_used: TestKindReport,
298}
299
300impl GasSnapshotDiff {
301    /// Returns the gas diff
302    ///
303    /// `> 0` if the source used more gas
304    /// `< 0` if the target used more gas
305    fn gas_change(&self) -> i128 {
306        self.source_gas_used.gas() as i128 - self.target_gas_used.gas() as i128
307    }
308
309    /// Determines the percentage change
310    fn gas_diff(&self) -> f64 {
311        self.gas_change() as f64 / self.target_gas_used.gas() as f64
312    }
313}
314
315/// Compares the set of tests with an existing gas snapshot.
316///
317/// Returns true all tests match
318fn check(
319    tests: Vec<SuiteTestResult>,
320    snaps: Vec<GasSnapshotEntry>,
321    tolerance: Option<u32>,
322) -> bool {
323    let snaps = snaps
324        .into_iter()
325        .map(|s| ((s.contract_name, s.signature), s.gas_used))
326        .collect::<HashMap<_, _>>();
327    let mut has_diff = false;
328    for test in tests {
329        if let Some(target_gas) =
330            snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
331        {
332            let source_gas = test.result.kind.report();
333            if !within_tolerance(source_gas.gas(), target_gas.gas(), tolerance) {
334                let _ = sh_println!(
335                    "Diff in \"{}::{}\": consumed \"{}\" gas, expected \"{}\" gas ",
336                    test.contract_name(),
337                    test.signature,
338                    source_gas,
339                    target_gas
340                );
341                has_diff = true;
342            }
343        } else {
344            let _ = sh_println!(
345                "No matching snapshot entry found for \"{}::{}\" in snapshot file",
346                test.contract_name(),
347                test.signature
348            );
349            has_diff = true;
350        }
351    }
352    !has_diff
353}
354
355/// Compare the set of tests with an existing gas snapshot.
356fn diff(tests: Vec<SuiteTestResult>, snaps: Vec<GasSnapshotEntry>) -> Result<()> {
357    let snaps = snaps
358        .into_iter()
359        .map(|s| ((s.contract_name, s.signature), s.gas_used))
360        .collect::<HashMap<_, _>>();
361    let mut diffs = Vec::with_capacity(tests.len());
362    for test in tests.into_iter() {
363        if let Some(target_gas_used) =
364            snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
365        {
366            diffs.push(GasSnapshotDiff {
367                source_gas_used: test.result.kind.report(),
368                signature: test.signature,
369                target_gas_used,
370            });
371        }
372    }
373    let mut overall_gas_change = 0i128;
374    let mut overall_gas_used = 0i128;
375
376    diffs.sort_by(|a, b| a.gas_diff().abs().total_cmp(&b.gas_diff().abs()));
377
378    for diff in diffs {
379        let gas_change = diff.gas_change();
380        overall_gas_change += gas_change;
381        overall_gas_used += diff.target_gas_used.gas() as i128;
382        let gas_diff = diff.gas_diff();
383        sh_println!(
384            "{} (gas: {} ({})) ",
385            diff.signature,
386            fmt_change(gas_change),
387            fmt_pct_change(gas_diff)
388        )?;
389    }
390
391    let overall_gas_diff = overall_gas_change as f64 / overall_gas_used as f64;
392    sh_println!(
393        "Overall gas change: {} ({})",
394        fmt_change(overall_gas_change),
395        fmt_pct_change(overall_gas_diff)
396    )?;
397    Ok(())
398}
399
400fn fmt_pct_change(change: f64) -> String {
401    let change_pct = change * 100.0;
402    match change.total_cmp(&0.0) {
403        Ordering::Less => format!("{change_pct:.3}%").green().to_string(),
404        Ordering::Equal => {
405            format!("{change_pct:.3}%")
406        }
407        Ordering::Greater => format!("{change_pct:.3}%").red().to_string(),
408    }
409}
410
411fn fmt_change(change: i128) -> String {
412    match change.cmp(&0) {
413        Ordering::Less => format!("{change}").green().to_string(),
414        Ordering::Equal => {
415            format!("{change}")
416        }
417        Ordering::Greater => format!("{change}").red().to_string(),
418    }
419}
420
/// Returns `true` if `source_gas` and `target_gas` are within the given tolerance of each
/// other.
///
/// The deviation is measured relative to the larger value, as `(1 - lo / hi) * 100`, and
/// must be strictly below `tolerance_pct`. Identical values always count as within
/// tolerance, including a tolerance of `0` and two zero gas values.
///
/// If `tolerance_pct` is `None`, this returns `true` only if both gas values are equal.
fn within_tolerance(source_gas: u64, target_gas: u64, tolerance_pct: Option<u32>) -> bool {
    // Checked first: with a tolerance, `0 < 0` and `0 / 0 = NaN` would otherwise reject
    // identical values.
    if source_gas == target_gas {
        return true;
    }
    let Some(tolerance) = tolerance_pct else {
        return false;
    };
    let hi = source_gas.max(target_gas);
    let lo = source_gas.min(target_gas);
    // The values differ here, so `hi > 0` and the division is well-defined.
    let deviation_pct = (1.0 - lo as f64 / hi as f64) * 100.0;
    deviation_pct < f64::from(tolerance)
}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a snapshot line, panicking if it is not a valid entry.
    fn parse_entry(line: &str) -> GasSnapshotEntry {
        line.parse().expect("valid gas snapshot entry")
    }

    /// Builds the expected entry for the given contract, signature and report.
    fn expected(contract: &str, signature: &str, gas_used: TestKindReport) -> GasSnapshotEntry {
        GasSnapshotEntry {
            contract_name: contract.to_string(),
            signature: signature.to_string(),
            gas_used,
        }
    }

    /// Builds an invariant report with empty metrics and no failed corpus replays.
    fn invariant(runs: usize, calls: usize, reverts: usize) -> TestKindReport {
        TestKindReport::Invariant {
            runs: runs as _,
            calls: calls as _,
            reverts: reverts as _,
            metrics: HashMap::default(),
            failed_corpus_replays: 0,
        }
    }

    #[test]
    fn test_tolerance() {
        for (source, target) in [(100, 105), (105, 100)] {
            assert!(within_tolerance(source, target, Some(5)));
        }
        for (source, target) in [(100, 106), (106, 100)] {
            assert!(!within_tolerance(source, target, Some(5)));
        }
        assert!(within_tolerance(100, 100, None));
    }

    #[test]
    fn can_parse_basic_gas_snapshot_entry() {
        assert_eq!(
            parse_entry("Test:deposit() (gas: 7222)"),
            expected("Test", "deposit()", TestKindReport::Unit { gas: 7222 })
        );
    }

    #[test]
    fn can_parse_fuzz_gas_snapshot_entry() {
        assert_eq!(
            parse_entry("Test:deposit() (runs: 256, μ: 100, ~:200)"),
            expected(
                "Test",
                "deposit()",
                TestKindReport::Fuzz { runs: 256, median_gas: 200, mean_gas: 100 }
            )
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry() {
        assert_eq!(
            parse_entry("Test:deposit() (runs: 256, calls: 100, reverts: 200)"),
            expected("Test", "deposit()", invariant(256, 100, 200))
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry2() {
        assert_eq!(
            parse_entry(
                "ERC20Invariants:invariantBalanceSum() (runs: 256, calls: 3840, reverts: 2388)"
            ),
            expected("ERC20Invariants", "invariantBalanceSum()", invariant(256, 3840, 2388))
        );
    }
}