1use super::test;
2use crate::result::{SuiteTestResult, TestKindReport, TestOutcome};
3use alloy_primitives::{U256, map::HashMap};
4use clap::{Parser, ValueHint, builder::RangedU64ValueParser};
5use eyre::{Context, Result};
6use foundry_cli::utils::STATIC_FUZZ_SEED;
7use regex::Regex;
8use std::{
9 cmp::Ordering,
10 fs,
11 io::{self, BufRead},
12 path::{Path, PathBuf},
13 str::FromStr,
14 sync::LazyLock,
15};
16use yansi::Paint;
17
18pub static RE_BASIC_SNAPSHOT_ENTRY: LazyLock<Regex> = LazyLock::new(|| {
21 Regex::new(r"(?P<file>(.*?)):(?P<sig>(\w+)\s*\((.*?)\))\s*\(((gas:)?\s*(?P<gas>\d+)|(runs:\s*(?P<runs>\d+),\s*μ:\s*(?P<avg>\d+),\s*~:\s*(?P<med>\d+))|(runs:\s*(?P<invruns>\d+),\s*calls:\s*(?P<calls>\d+),\s*reverts:\s*(?P<reverts>\d+)))\)").unwrap()
22});
23
24#[derive(Clone, Debug, Parser)]
26pub struct GasSnapshotArgs {
27 #[arg(
31 conflicts_with = "snap",
32 long,
33 value_hint = ValueHint::FilePath,
34 value_name = "SNAPSHOT_FILE",
35 )]
36 diff: Option<Option<PathBuf>>,
37
38 #[arg(
44 conflicts_with = "diff",
45 long,
46 value_hint = ValueHint::FilePath,
47 value_name = "SNAPSHOT_FILE",
48 )]
49 check: Option<Option<PathBuf>>,
50
51 #[arg(long, hide(true))]
54 format: Option<Format>,
55
56 #[arg(
58 long,
59 default_value = ".gas-snapshot",
60 value_hint = ValueHint::FilePath,
61 value_name = "FILE",
62 )]
63 snap: PathBuf,
64
65 #[arg(
67 long,
68 value_parser = RangedU64ValueParser::<u32>::new().range(0..100),
69 value_name = "SNAPSHOT_THRESHOLD"
70 )]
71 tolerance: Option<u32>,
72
73 #[command(flatten)]
75 pub(crate) test: test::TestArgs,
76
77 #[command(flatten)]
79 config: GasSnapshotConfig,
80}
81
82impl GasSnapshotArgs {
83 pub fn is_watch(&self) -> bool {
85 self.test.is_watch()
86 }
87
88 pub(crate) fn watchexec_config(&self) -> Result<watchexec::Config> {
90 self.test.watchexec_config()
91 }
92
93 pub async fn run(mut self) -> Result<()> {
94 self.test.fuzz_seed = Some(U256::from_be_bytes(STATIC_FUZZ_SEED));
96
97 let outcome = self.test.execute_tests().await?;
98 outcome.ensure_ok(false)?;
99 let tests = self.config.apply(outcome);
100
101 if let Some(path) = self.diff {
102 let snap = path.as_ref().unwrap_or(&self.snap);
103 let snaps = read_gas_snapshot(snap)?;
104 diff(tests, snaps)?;
105 } else if let Some(path) = self.check {
106 let snap = path.as_ref().unwrap_or(&self.snap);
107 let snaps = read_gas_snapshot(snap)?;
108 if check(tests, snaps, self.tolerance) {
109 std::process::exit(0)
110 } else {
111 std::process::exit(1)
112 }
113 } else {
114 write_to_gas_snapshot_file(&tests, self.snap, self.format)?;
115 }
116 Ok(())
117 }
118}
119
/// Output formats selectable through the hidden `--format` flag.
#[derive(Clone, Debug)]
pub enum Format {
    /// Selected by the strings `t` and `table`.
    Table,
}

impl FromStr for Format {
    type Err = String;

    /// Parses `t` or `table` (case-sensitive) into [`Format::Table`].
    ///
    /// Any other input yields an error message naming the rejected value.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if matches!(s, "t" | "table") {
            return Ok(Self::Table);
        }
        Err(format!("Unrecognized format `{s}`"))
    }
}
136
137#[derive(Clone, Debug, Default, Parser)]
139struct GasSnapshotConfig {
140 #[arg(long)]
142 asc: bool,
143
144 #[arg(conflicts_with = "asc", long)]
146 desc: bool,
147
148 #[arg(long, value_name = "MIN_GAS")]
150 min: Option<u64>,
151
152 #[arg(long, value_name = "MAX_GAS")]
154 max: Option<u64>,
155}
156
157impl GasSnapshotConfig {
158 fn is_in_gas_range(&self, gas_used: u64) -> bool {
159 if let Some(min) = self.min
160 && gas_used < min
161 {
162 return false;
163 }
164 if let Some(max) = self.max
165 && gas_used > max
166 {
167 return false;
168 }
169 true
170 }
171
172 fn apply(&self, outcome: TestOutcome) -> Vec<SuiteTestResult> {
173 let mut tests = outcome
174 .into_tests()
175 .filter(|test| self.is_in_gas_range(test.gas_used()))
176 .collect::<Vec<_>>();
177
178 if self.asc {
179 tests.sort_by_key(|a| a.gas_used());
180 } else if self.desc {
181 tests.sort_by_key(|b| std::cmp::Reverse(b.gas_used()))
182 }
183
184 tests
185 }
186}
187
188#[derive(Clone, Debug, PartialEq, Eq)]
195pub struct GasSnapshotEntry {
196 pub contract_name: String,
197 pub signature: String,
198 pub gas_used: TestKindReport,
199}
200
201impl FromStr for GasSnapshotEntry {
202 type Err = String;
203
204 fn from_str(s: &str) -> Result<Self, Self::Err> {
205 RE_BASIC_SNAPSHOT_ENTRY
206 .captures(s)
207 .and_then(|cap| {
208 cap.name("file").and_then(|file| {
209 cap.name("sig").and_then(|sig| {
210 if let Some(gas) = cap.name("gas") {
211 Some(Self {
212 contract_name: file.as_str().to_string(),
213 signature: sig.as_str().to_string(),
214 gas_used: TestKindReport::Unit {
215 gas: gas.as_str().parse().unwrap(),
216 },
217 })
218 } else if let Some(runs) = cap.name("runs") {
219 cap.name("avg")
220 .and_then(|avg| cap.name("med").map(|med| (runs, avg, med)))
221 .map(|(runs, avg, med)| Self {
222 contract_name: file.as_str().to_string(),
223 signature: sig.as_str().to_string(),
224 gas_used: TestKindReport::Fuzz {
225 runs: runs.as_str().parse().unwrap(),
226 median_gas: med.as_str().parse().unwrap(),
227 mean_gas: avg.as_str().parse().unwrap(),
228 },
229 })
230 } else {
231 cap.name("invruns")
232 .and_then(|runs| {
233 cap.name("calls").and_then(|avg| {
234 cap.name("reverts").map(|med| (runs, avg, med))
235 })
236 })
237 .map(|(runs, calls, reverts)| Self {
238 contract_name: file.as_str().to_string(),
239 signature: sig.as_str().to_string(),
240 gas_used: TestKindReport::Invariant {
241 runs: runs.as_str().parse().unwrap(),
242 calls: calls.as_str().parse().unwrap(),
243 reverts: reverts.as_str().parse().unwrap(),
244 metrics: HashMap::default(),
245 failed_corpus_replays: 0,
246 },
247 })
248 }
249 })
250 })
251 })
252 .ok_or_else(|| format!("Could not extract Snapshot Entry for {s}"))
253 }
254}
255
256fn read_gas_snapshot(path: impl AsRef<Path>) -> Result<Vec<GasSnapshotEntry>> {
258 let path = path.as_ref();
259 let mut entries = Vec::new();
260 for line in io::BufReader::new(
261 fs::File::open(path)
262 .wrap_err(format!("failed to read snapshot file \"{}\"", path.display()))?,
263 )
264 .lines()
265 {
266 entries
267 .push(GasSnapshotEntry::from_str(line?.as_str()).map_err(|err| eyre::eyre!("{err}"))?);
268 }
269 Ok(entries)
270}
271
272fn write_to_gas_snapshot_file(
274 tests: &[SuiteTestResult],
275 path: impl AsRef<Path>,
276 _format: Option<Format>,
277) -> Result<()> {
278 let mut reports = tests
279 .iter()
280 .map(|test| {
281 format!("{}:{} {}", test.contract_name(), test.signature, test.result.kind.report())
282 })
283 .collect::<Vec<_>>();
284
285 reports.sort();
287
288 let content = reports.join("\n");
289 Ok(fs::write(path, content)?)
290}
291
292#[derive(Clone, Debug, PartialEq, Eq)]
294pub struct GasSnapshotDiff {
295 pub signature: String,
296 pub source_gas_used: TestKindReport,
297 pub target_gas_used: TestKindReport,
298}
299
300impl GasSnapshotDiff {
301 fn gas_change(&self) -> i128 {
306 self.source_gas_used.gas() as i128 - self.target_gas_used.gas() as i128
307 }
308
309 fn gas_diff(&self) -> f64 {
311 self.gas_change() as f64 / self.target_gas_used.gas() as f64
312 }
313}
314
315fn check(
319 tests: Vec<SuiteTestResult>,
320 snaps: Vec<GasSnapshotEntry>,
321 tolerance: Option<u32>,
322) -> bool {
323 let snaps = snaps
324 .into_iter()
325 .map(|s| ((s.contract_name, s.signature), s.gas_used))
326 .collect::<HashMap<_, _>>();
327 let mut has_diff = false;
328 for test in tests {
329 if let Some(target_gas) =
330 snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
331 {
332 let source_gas = test.result.kind.report();
333 if !within_tolerance(source_gas.gas(), target_gas.gas(), tolerance) {
334 let _ = sh_println!(
335 "Diff in \"{}::{}\": consumed \"{}\" gas, expected \"{}\" gas ",
336 test.contract_name(),
337 test.signature,
338 source_gas,
339 target_gas
340 );
341 has_diff = true;
342 }
343 } else {
344 let _ = sh_println!(
345 "No matching snapshot entry found for \"{}::{}\" in snapshot file",
346 test.contract_name(),
347 test.signature
348 );
349 has_diff = true;
350 }
351 }
352 !has_diff
353}
354
355fn diff(tests: Vec<SuiteTestResult>, snaps: Vec<GasSnapshotEntry>) -> Result<()> {
357 let snaps = snaps
358 .into_iter()
359 .map(|s| ((s.contract_name, s.signature), s.gas_used))
360 .collect::<HashMap<_, _>>();
361 let mut diffs = Vec::with_capacity(tests.len());
362 for test in tests.into_iter() {
363 if let Some(target_gas_used) =
364 snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
365 {
366 diffs.push(GasSnapshotDiff {
367 source_gas_used: test.result.kind.report(),
368 signature: test.signature,
369 target_gas_used,
370 });
371 }
372 }
373 let mut overall_gas_change = 0i128;
374 let mut overall_gas_used = 0i128;
375
376 diffs.sort_by(|a, b| a.gas_diff().abs().total_cmp(&b.gas_diff().abs()));
377
378 for diff in diffs {
379 let gas_change = diff.gas_change();
380 overall_gas_change += gas_change;
381 overall_gas_used += diff.target_gas_used.gas() as i128;
382 let gas_diff = diff.gas_diff();
383 sh_println!(
384 "{} (gas: {} ({})) ",
385 diff.signature,
386 fmt_change(gas_change),
387 fmt_pct_change(gas_diff)
388 )?;
389 }
390
391 let overall_gas_diff = overall_gas_change as f64 / overall_gas_used as f64;
392 sh_println!(
393 "Overall gas change: {} ({})",
394 fmt_change(overall_gas_change),
395 fmt_pct_change(overall_gas_diff)
396 )?;
397 Ok(())
398}
399
400fn fmt_pct_change(change: f64) -> String {
401 let change_pct = change * 100.0;
402 match change.total_cmp(&0.0) {
403 Ordering::Less => format!("{change_pct:.3}%").green().to_string(),
404 Ordering::Equal => {
405 format!("{change_pct:.3}%")
406 }
407 Ordering::Greater => format!("{change_pct:.3}%").red().to_string(),
408 }
409}
410
411fn fmt_change(change: i128) -> String {
412 match change.cmp(&0) {
413 Ordering::Less => format!("{change}").green().to_string(),
414 Ordering::Equal => {
415 format!("{change}")
416 }
417 Ordering::Greater => format!("{change}").red().to_string(),
418 }
419}
420
/// Returns `true` if `source_gas` is within `tolerance_pct` percent of `target_gas`.
///
/// The deviation is measured relative to the larger of the two values and has to be strictly
/// below the tolerance. Identical values always pass, including with a tolerance of `0` and
/// when both values are zero. Without a tolerance the values have to be equal.
fn within_tolerance(source_gas: u64, target_gas: u64, tolerance_pct: Option<u32>) -> bool {
    // Equal values never deviate. Handling them up front avoids `0 / 0` (NaN) for two zero
    // values and the always-false `0 < 0` comparison for a zero tolerance.
    if source_gas == target_gas {
        return true;
    }

    match tolerance_pct {
        Some(tolerance) => {
            // The values differ, so `hi` is at least 1 and the division is well defined.
            let hi = source_gas.max(target_gas);
            let lo = source_gas.min(target_gas);
            let deviation_pct = (1. - (lo as f64 / hi as f64)) * 100.;
            deviation_pct < f64::from(tolerance)
        }
        None => false,
    }
}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `line` and asserts that it yields an entry with exactly the given parts.
    #[track_caller]
    fn assert_parses_to(
        line: &str,
        contract_name: &str,
        signature: &str,
        gas_used: TestKindReport,
    ) {
        let parsed: GasSnapshotEntry = line.parse().unwrap();
        let expected = GasSnapshotEntry {
            contract_name: contract_name.to_string(),
            signature: signature.to_string(),
            gas_used,
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_tolerance() {
        // (source, target, tolerance, expected result)
        let cases = [
            (100, 105, Some(5), true),
            (105, 100, Some(5), true),
            (100, 106, Some(5), false),
            (106, 100, Some(5), false),
            (100, 100, None, true),
        ];
        for (source, target, tolerance, expected) in cases {
            assert!(
                within_tolerance(source, target, tolerance) == expected,
                "within_tolerance({source}, {target}, {tolerance:?}) should be {expected}"
            );
        }
    }

    #[test]
    fn can_parse_basic_gas_snapshot_entry() {
        assert_parses_to(
            "Test:deposit() (gas: 7222)",
            "Test",
            "deposit()",
            TestKindReport::Unit { gas: 7222 },
        );
    }

    #[test]
    fn can_parse_fuzz_gas_snapshot_entry() {
        assert_parses_to(
            "Test:deposit() (runs: 256, μ: 100, ~:200)",
            "Test",
            "deposit()",
            TestKindReport::Fuzz { runs: 256, median_gas: 200, mean_gas: 100 },
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry() {
        assert_parses_to(
            "Test:deposit() (runs: 256, calls: 100, reverts: 200)",
            "Test",
            "deposit()",
            TestKindReport::Invariant {
                runs: 256,
                calls: 100,
                reverts: 200,
                metrics: HashMap::default(),
                failed_corpus_replays: 0,
            },
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry2() {
        assert_parses_to(
            "ERC20Invariants:invariantBalanceSum() (runs: 256, calls: 3840, reverts: 2388)",
            "ERC20Invariants",
            "invariantBalanceSum()",
            TestKindReport::Invariant {
                runs: 256,
                calls: 3840,
                reverts: 2388,
                metrics: HashMap::default(),
                failed_corpus_replays: 0,
            },
        );
    }
}