Skip to main content

forge/
gas_report.rs

//! Gas reports.
2
3use crate::{
4    constants::{CHEATCODE_ADDRESS, HARDHAT_CONSOLE_ADDRESS},
5    traces::{CallTraceArena, CallTraceDecoder, CallTraceNode, DecodedCallData},
6};
7use alloy_primitives::map::HashSet;
8use comfy_table::{Cell, Color, Table, modifiers::UTF8_ROUND_CORNERS};
9use foundry_common::{
10    TestFunctionExt, calc,
11    reports::{ReportKind, report_kind},
12};
13use foundry_evm::traces::CallKind;
14
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17use std::{collections::BTreeMap, fmt::Display};
18
19/// Represents the gas report for a set of contracts.
20#[derive(Clone, Debug, Default, Serialize, Deserialize)]
21pub struct GasReport {
22    /// Whether to report any contracts.
23    report_any: bool,
24    /// What kind of report to generate.
25    report_kind: ReportKind,
26    /// Contracts to generate the report for.
27    report_for: HashSet<String>,
28    /// Contracts to ignore when generating the report.
29    ignore: HashSet<String>,
30    /// Whether to include gas reports for tests.
31    include_tests: bool,
32    /// All contracts that were analyzed grouped by their identifier
33    /// ``test/Counter.t.sol:CounterTest
34    pub contracts: BTreeMap<String, ContractInfo>,
35}
36
37impl GasReport {
38    pub fn new(
39        report_for: impl IntoIterator<Item = String>,
40        ignore: impl IntoIterator<Item = String>,
41        include_tests: bool,
42    ) -> Self {
43        let report_for = report_for.into_iter().collect::<HashSet<_>>();
44        let ignore = ignore.into_iter().collect::<HashSet<_>>();
45        let report_any = report_for.is_empty() || report_for.contains("*");
46        Self {
47            report_any,
48            report_kind: report_kind(),
49            report_for,
50            ignore,
51            include_tests,
52            ..Default::default()
53        }
54    }
55
56    /// Whether the given contract should be reported.
57    #[instrument(level = "trace", skip(self), ret)]
58    fn should_report(&self, contract_name: &str) -> bool {
59        if self.ignore.contains(contract_name) {
60            let contains_anyway = self.report_for.contains(contract_name);
61            if contains_anyway {
62                // If the user listed the contract in 'gas_reports' (the foundry.toml field) a
63                // report for the contract is generated even if it's listed in the ignore
64                // list. This is addressed this way because getting a report you don't expect is
65                // preferable than not getting one you expect. A warning is printed to stderr
66                // indicating the "double listing".
67                let _ = sh_warn!(
68                    "{contract_name} is listed in both 'gas_reports' and 'gas_reports_ignore'."
69                );
70            }
71            return contains_anyway;
72        }
73        self.report_any || self.report_for.contains(contract_name)
74    }
75
76    /// Analyzes the given traces and generates a gas report.
77    pub async fn analyze(
78        &mut self,
79        arenas: impl IntoIterator<Item = &CallTraceArena>,
80        decoder: &CallTraceDecoder,
81    ) {
82        for node in arenas.into_iter().flat_map(|arena| arena.nodes()) {
83            self.analyze_node(node, decoder).await;
84        }
85    }
86
87    async fn analyze_node(&mut self, node: &CallTraceNode, decoder: &CallTraceDecoder) {
88        let trace = &node.trace;
89
90        if trace.address == CHEATCODE_ADDRESS || trace.address == HARDHAT_CONSOLE_ADDRESS {
91            return;
92        }
93
94        let Some(name) = decoder.contracts.get(&node.trace.address) else { return };
95        let contract_name = name.rsplit(':').next().unwrap_or(name);
96
97        if !self.should_report(contract_name) {
98            return;
99        }
100        let contract_info = self.contracts.entry(name.to_string()).or_default();
101        let is_create_call = trace.kind.is_any_create();
102
103        // Record contract deployment size.
104        if is_create_call {
105            trace!(contract_name, "adding create size info");
106            contract_info.size = trace.data.len();
107        }
108
109        // Only include top-level calls which account for calldata and base (21.000) cost.
110        // Only include Calls and Creates as only these calls are isolated in inspector.
111        if trace.depth > 1 && (trace.kind == CallKind::Call || is_create_call) {
112            return;
113        }
114
115        let decoded = || decoder.decode_function(&node.trace);
116
117        if is_create_call {
118            trace!(contract_name, "adding create gas info");
119            contract_info.gas = trace.gas_used;
120        } else if let Some(DecodedCallData { signature, .. }) = decoded().await.call_data {
121            let name = signature.split('(').next().unwrap();
122            // ignore any test/setup functions
123            if self.include_tests || !name.test_function_kind().is_known() {
124                trace!(contract_name, signature, "adding gas info");
125                let gas_info = contract_info
126                    .functions
127                    .entry(name.to_string())
128                    .or_default()
129                    .entry(signature.clone())
130                    .or_default();
131                gas_info.frames.push(trace.gas_used);
132            }
133        }
134    }
135
136    /// Finalizes the gas report by calculating the min, max, mean, and median for each function.
137    #[must_use]
138    pub fn finalize(mut self) -> Self {
139        trace!("finalizing gas report");
140        for contract in self.contracts.values_mut() {
141            for sigs in contract.functions.values_mut() {
142                for func in sigs.values_mut() {
143                    func.frames.sort_unstable();
144                    func.min = func.frames.first().copied().unwrap_or_default();
145                    func.max = func.frames.last().copied().unwrap_or_default();
146                    func.mean = calc::mean(&func.frames);
147                    func.median = calc::median_sorted(&func.frames);
148                    func.calls = func.frames.len() as u64;
149                }
150            }
151        }
152        self
153    }
154}
155
156impl Display for GasReport {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
158        match self.report_kind {
159            ReportKind::Text => {
160                for (name, contract) in &self.contracts {
161                    if contract.functions.is_empty() {
162                        trace!(name, "gas report contract without functions");
163                        continue;
164                    }
165
166                    let table = self.format_table_output(contract, name);
167                    writeln!(f, "\n{table}")?;
168                }
169            }
170            ReportKind::JSON => {
171                writeln!(f, "{}", &self.format_json_output())?;
172            }
173        }
174
175        Ok(())
176    }
177}
178
179impl GasReport {
180    fn format_json_output(&self) -> String {
181        serde_json::to_string(
182            &self
183                .contracts
184                .iter()
185                .filter_map(|(name, contract)| {
186                    if contract.functions.is_empty() {
187                        trace!(name, "gas report contract without functions");
188                        return None;
189                    }
190
191                    let functions = contract
192                        .functions
193                        .iter()
194                        .flat_map(|(_, sigs)| {
195                            sigs.iter().map(|(sig, gas_info)| {
196                                let display_name = sig.replace(':', "");
197                                (display_name, gas_info)
198                            })
199                        })
200                        .collect::<BTreeMap<_, _>>();
201
202                    Some(json!({
203                        "contract": name,
204                        "deployment": {
205                            "gas": contract.gas,
206                            "size": contract.size,
207                        },
208                        "functions": functions,
209                    }))
210                })
211                .collect::<Vec<_>>(),
212        )
213        .unwrap()
214    }
215
216    fn format_table_output(&self, contract: &ContractInfo, name: &str) -> Table {
217        let mut table = Table::new();
218        table.apply_modifier(UTF8_ROUND_CORNERS);
219
220        table.set_header(vec![Cell::new(format!("{name} Contract")).fg(Color::Magenta)]);
221
222        table.add_row(vec![
223            Cell::new("Deployment Cost").fg(Color::Cyan),
224            Cell::new("Deployment Size").fg(Color::Cyan),
225        ]);
226        table.add_row(vec![
227            Cell::new(contract.gas.to_string()),
228            Cell::new(contract.size.to_string()),
229        ]);
230
231        // Add a blank row to separate deployment info from function info.
232        table.add_row(vec![Cell::new("")]);
233
234        table.add_row(vec![
235            Cell::new("Function Name"),
236            Cell::new("Min").fg(Color::Green),
237            Cell::new("Avg").fg(Color::Yellow),
238            Cell::new("Median").fg(Color::Yellow),
239            Cell::new("Max").fg(Color::Red),
240            Cell::new("# Calls").fg(Color::Cyan),
241        ]);
242
243        contract.functions.iter().for_each(|(fname, sigs)| {
244            sigs.iter().for_each(|(sig, gas_info)| {
245                // Show function signature if overloaded else display function name.
246                let display_name =
247                    if sigs.len() == 1 { fname.to_string() } else { sig.replace(':', "") };
248
249                table.add_row(vec![
250                    Cell::new(display_name),
251                    Cell::new(gas_info.min.to_string()).fg(Color::Green),
252                    Cell::new(gas_info.mean.to_string()).fg(Color::Yellow),
253                    Cell::new(gas_info.median.to_string()).fg(Color::Yellow),
254                    Cell::new(gas_info.max.to_string()).fg(Color::Red),
255                    Cell::new(gas_info.calls.to_string()),
256                ]);
257            })
258        });
259
260        table
261    }
262}
263
264#[derive(Clone, Debug, Default, Serialize, Deserialize)]
265pub struct ContractInfo {
266    pub gas: u64,
267    pub size: usize,
268    /// Function name -> Function signature -> GasInfo
269    pub functions: BTreeMap<String, BTreeMap<String, GasInfo>>,
270}
271
272#[derive(Clone, Debug, Default, Serialize, Deserialize)]
273pub struct GasInfo {
274    pub calls: u64,
275    pub min: u64,
276    pub mean: u64,
277    pub median: u64,
278    pub max: u64,
279
280    #[serde(skip)]
281    pub frames: Vec<u64>,
282}