diff --git a/examples/demo-prover/Makefile b/examples/demo-prover/Makefile new file mode 100644 index 000000000..3c6fb360f --- /dev/null +++ b/examples/demo-prover/Makefile @@ -0,0 +1,19 @@ +# Directories and paths +TRACER_DIR = ../../utils/zk-cycle-utils/tracer +ELF_PATH_TRACER = ../../../examples/demo-prover/target/riscv-guest/riscv32im-risc0-zkvm-elf/release/rollup +TRACE_PATH_TRACER = ../../../examples/demo-prover/host/rollup.trace + +# This allows you to pass additional flags when you call `make run-tracer`. +# For example: `make run-tracer ADDITIONAL_FLAGS="--some-flag"` +ADDITIONAL_FLAGS ?= + +.PHONY: generate-files run-tracer + +all: generate-files run-tracer + +generate-files: + ROLLUP_TRACE=rollup.trace cargo bench --bench prover_bench --features bench + +run-tracer: + @cd $(TRACER_DIR) && \ + cargo run --release -- --no-raw-counts --rollup-elf $(ELF_PATH_TRACER) --rollup-trace $(TRACE_PATH_TRACER) $(ADDITIONAL_FLAGS) diff --git a/examples/demo-prover/host/benches/prover_bench.rs b/examples/demo-prover/host/benches/prover_bench.rs index 6c18035b3..fdbdc5db5 100644 --- a/examples/demo-prover/host/benches/prover_bench.rs +++ b/examples/demo-prover/host/benches/prover_bench.rs @@ -52,12 +52,17 @@ impl RegexAppender { impl log::Log for RegexAppender { fn log(&self, record: &log::Record) { if let Some(captures) = self.regex.captures(record.args().to_string().as_str()) { + let mut file_guard = self.file.lock().unwrap(); if let Some(matched_pc) = captures.get(1) { let pc_value_num = u64::from_str_radix(&matched_pc.as_str()[2..], 16).unwrap(); - let pc_value = format!("{}\n", pc_value_num); - let mut file_guard = self.file.lock().unwrap(); + let pc_value = format!("{}\t", pc_value_num); file_guard.write_all(pc_value.as_bytes()).unwrap(); } + if let Some(matched_iname) = captures.get(2) { + let iname = matched_iname.as_str().to_uppercase(); + let iname_value = format!("{}\n", iname); + file_guard.write_all(iname_value.as_bytes()).unwrap(); + } } } @@ -69,8 +74,8 @@ impl log::Log for RegexAppender { } fn get_config(rollup_trace: &str) -> Config { - let regex_pattern = r".*?pc: (0x[0-9a-fA-F]+), insn.*"; - // let log_file = "/Users/dubbelosix/sovereign/examples/demo-prover/matched_pattern.log"; + // [942786] pc: 0x0008e564, insn: 0xffc67613 => andi x12, x12, -4 + let regex_pattern = r".*?pc: (0x[0-9a-fA-F]+), insn: .*?=> ([a-z]*?) "; let custom_appender = RegexAppender::new(regex_pattern, rollup_trace); diff --git a/examples/demo-prover/methods/guest/src/bin/rollup.rs b/examples/demo-prover/methods/guest/src/bin/rollup.rs index 9f0e06573..0f1fe8278 100644 --- a/examples/demo-prover/methods/guest/src/bin/rollup.rs +++ b/examples/demo-prover/methods/guest/src/bin/rollup.rs @@ -11,6 +11,7 @@ use celestia::{BlobWithSender, CelestiaHeader}; use const_rollup_config::{ROLLUP_NAMESPACE_RAW, SEQUENCER_DA_ADDRESS}; use demo_stf::app::create_zk_app_template; use demo_stf::ArrayWitness; + use risc0_adapter::guest::Risc0Guest; use risc0_zkvm::guest::env; use sov_rollup_interface::crypto::NoOpHasher; @@ -103,7 +104,6 @@ pub fn main() { let metrics_syscall_name = unsafe { risc0_zkvm_platform::syscall::SyscallName::from_bytes_with_nul(cycle_string.as_ptr()) }; - risc0_zkvm::guest::env::send_recv_slice::(metrics_syscall_name, &serialized); } } diff --git a/utils/zk-cycle-utils/tracer/src/main.rs b/utils/zk-cycle-utils/tracer/src/main.rs index 4feb167db..2f6e5dfd5 100644 --- a/utils/zk-cycle-utils/tracer/src/main.rs +++ b/utils/zk-cycle-utils/tracer/src/main.rs @@ -38,6 +38,15 @@ struct Args { #[arg(short, long)] /// Strip the hashes from the function name while printing strip_hashes: bool, + + #[arg(short, long)] + /// Function name to target for getting stack counts + function_name: Option, + + #[arg(short, long)] + /// Exclude functions matching these patterns from display + /// usage: -e func1 -e func2 -e func3 + exclude_view: Vec, } fn strip_hash(name_with_hash: &str) -> String { @@ -45,17 +54,53 @@ fn strip_hash(name_with_hash: &str) -> String { re.replace(name_with_hash, "").to_string() } -fn print_intruction_counts(count_vec: Vec<(&String, &usize)>, top_n: usize, strip_hashes: bool) { +fn get_cycle_count(insn: &str) -> Result { + // The opcodes and their cycle counts are taken from + // https://github.com/risc0/risc0/blob/main/risc0/zkvm/src/host/server/opcode.rs + match insn { + "LB" | "LH" | "LW" | "LBU" | "LHU" | "ADDI" | "SLLI" | "SLTI" | "SLTIU" | + "AUIPC" | "SB" | "SH" | "SW" | "ADD" | "SUB" | "SLL" | "SLT" | "SLTU" | + "XOR" | "SRL" | "SRA" | "OR" | "AND" | "MUL" | "MULH" | "MULSU" | "MULU" | + "LUI" | "BEQ" | "BNE" | "BLT" | "BGE" | "BLTU" | "BGEU" | "JALR" | "JAL" | + "ECALL" | "EBREAK" => Ok(1), + + // Don't see this in the risc0 code base, but MUL, MULH, MULSU, and MULU all take 1 cycle, + // so going with that for MULHU as well. + "MULHU" => Ok(1), + + "XORI" | "ORI" | "ANDI" | "SRLI" | "SRAI" | "DIV" | "DIVU" | "REM" | "REMU" => Ok(2), + + _ => Err("Decode error"), + } +} + +fn print_intruction_counts(first_header: &str, + count_vec: Vec<(String, usize)>, + top_n: usize, + strip_hashes: bool, + exclude_list: Option<&[String]>) { let mut table = Table::new(); table.set_format(*format::consts::FORMAT_DEFAULT); table.set_titles(Row::new(vec![ - Cell::new("Function Name"), + Cell::new(first_header), Cell::new("Instruction Count"), ])); let wrap_width = 90; let mut row_count = 0; for (key, value) in count_vec { + let mut cont = false; + if let Some(ev) = exclude_list { + for e in ev { + if key.contains(e) { + cont = true; + break + } + } + if cont { + continue + } + } let mut stripped_key = key.clone(); if strip_hashes { stripped_key = strip_hash(&key); @@ -75,7 +120,18 @@ fn print_intruction_counts(count_vec: Vec<(&String, &usize)>, top_n: usize, stri table.printstd(); } -fn _build_lookups_radare_2( +fn focused_stack_counts(function_stack: &[String], + filtered_stack_counts: &mut HashMap, usize>, + function_name: &str, + instruction: &str) { + if let Some(index) = function_stack.iter().position(|s| s == function_name) { + let truncated_stack = &function_stack[0..=index]; + let count = filtered_stack_counts.entry(truncated_stack.to_vec()).or_insert(0); + *count += get_cycle_count(instruction).unwrap(); + } +} + +fn _build_radare2_lookups( start_lookup: &mut HashMap, end_lookup: &mut HashMap, func_range_lookup: &mut HashMap, @@ -130,10 +186,18 @@ fn build_goblin_lookups( Ok(()) } -fn increment_stack_counts(instruction_counts: &mut HashMap, function_stack: &[String]) { - for function_name in function_stack { - *instruction_counts.entry(function_name.clone()).or_insert(0) += 1; +fn increment_stack_counts(instruction_counts: &mut HashMap, + function_stack: &[String], + filtered_stack_counts: &mut HashMap, usize>, + function_name: &Option, + instruction: &str) { + for f in function_stack { + *instruction_counts.entry(f.clone()).or_insert(0) += get_cycle_count(instruction).unwrap(); + } + if let Some(f) = function_name { + focused_stack_counts(function_stack, filtered_stack_counts, &f, instruction) } + } fn main() -> std::io::Result<()> { @@ -145,6 +209,8 @@ fn main() -> std::io::Result<()> { let no_stack_counts = args.no_stack_counts; let no_raw_counts = args.no_raw_counts; let strip_hashes = args.strip_hashes; + let function_name = args.function_name; + let exclude_view = args.exclude_view; let mut start_lookup = HashMap::new(); let mut end_lookup = HashMap::new(); @@ -153,7 +219,7 @@ fn main() -> std::io::Result<()> { let mut function_ranges: Vec<(u64, u64, String)> = func_range_lookup .iter() - .map(|(function_name, &(start, end))| (start, end, function_name.clone())) + .map(|(f, &(start, end))| (start, end, f.clone())) .collect(); function_ranges.sort_by_key(|&(start, _, _)| start); @@ -162,6 +228,7 @@ fn main() -> std::io::Result<()> { let mut function_stack: Vec = Vec::new(); let mut instruction_counts: HashMap = HashMap::new(); let mut counts_without_callgraph: HashMap = HashMap::new(); + let mut filtered_stack_counts: HashMap, usize> = HashMap::new(); let total_lines = file_content.lines().count() as u64; let mut current_function_range : (u64,u64) = (0,0); @@ -176,7 +243,9 @@ fn main() -> std::io::Result<()> { if c % &update_interval == 0 { pb.inc(update_interval as u64); } - let pc = line.parse().unwrap(); + let mut parts = line.split("\t"); + let pc = parts.next().unwrap_or_default().parse().unwrap(); + let instruction = parts.next().unwrap_or_default(); // Raw counts without considering the callgraph at all // we're just checking if the PC belongs to a function @@ -193,9 +262,9 @@ fn main() -> std::io::Result<()> { } }) { let (_, _, fname) = &function_ranges[index]; - *counts_without_callgraph.entry(fname.clone()).or_insert(0) += 1; + *counts_without_callgraph.entry(fname.clone()).or_insert(0) += get_cycle_count(instruction).unwrap(); } else { - *counts_without_callgraph.entry("anonymous".to_string()).or_insert(0) += 1; + *counts_without_callgraph.entry("anonymous".to_string()).or_insert(0) += get_cycle_count(instruction).unwrap(); } // The next section considers the callstack @@ -204,17 +273,17 @@ fn main() -> std::io::Result<()> { // we are still in the current function if pc > current_function_range.0 && pc <= current_function_range.1 { - increment_stack_counts(&mut instruction_counts, &function_stack); + increment_stack_counts(&mut instruction_counts, &function_stack, &mut filtered_stack_counts, &function_name,instruction); continue; } // jump to a new function (or the same one) - if let Some(function_name) = start_lookup.get(&pc) { - increment_stack_counts(&mut instruction_counts, &function_stack); + if let Some(f) = start_lookup.get(&pc) { + increment_stack_counts(&mut instruction_counts, &function_stack, &mut filtered_stack_counts, &function_name,instruction); // jump to a new function (not recursive) - if !function_stack.contains(&function_name) { - function_stack.push(function_name.clone()); - current_function_range = *func_range_lookup.get(function_name).unwrap(); + if !function_stack.contains(&f) { + function_stack.push(f.clone()); + current_function_range = *func_range_lookup.get(f).unwrap(); } } else { // this means pc now points to an instruction that is @@ -237,33 +306,62 @@ fn main() -> std::io::Result<()> { if unwind_found { function_stack.truncate(unwind_point + 1); - increment_stack_counts(&mut instruction_counts, &function_stack); + increment_stack_counts(&mut instruction_counts, &function_stack, &mut filtered_stack_counts, &function_name,instruction); continue; } // if no unwind point has been found, that means we jumped to some random location // so we'll just increment the counts for everything in the stack - increment_stack_counts(&mut instruction_counts, &function_stack); + increment_stack_counts(&mut instruction_counts, &function_stack, &mut filtered_stack_counts, &function_name,instruction); } } pb.finish_with_message("done"); - let mut raw_counts: Vec<(&String, &usize)> = instruction_counts.iter().collect(); + let mut raw_counts: Vec<(String, usize)> = instruction_counts + .iter() + .map(|(key, value)| (key.clone(), value.clone())) + .collect(); raw_counts.sort_by(|a, b| b.1.cmp(&a.1)); println!("\n\nTotal instructions in trace: {}", total_lines); if !no_stack_counts { println!("\n\n Instruction counts considering call graph"); - print_intruction_counts(raw_counts, top_n, strip_hashes); + print_intruction_counts("Function Name", raw_counts, top_n, strip_hashes,Some(&exclude_view)); } - let mut raw_counts: Vec<(&String, &usize)> = counts_without_callgraph.iter().collect(); + let mut raw_counts: Vec<(String, usize)> = counts_without_callgraph + .iter() + .map(|(key, value)| (key.clone(), value.clone())) + .collect(); raw_counts.sort_by(|a, b| b.1.cmp(&a.1)); if !no_raw_counts { println!("\n\n Instruction counts ignoring call graph"); - print_intruction_counts(raw_counts, top_n, strip_hashes); + print_intruction_counts("Function Name",raw_counts, top_n, strip_hashes,Some(&exclude_view)); + } + + let mut raw_counts: Vec<(String, usize)> = filtered_stack_counts + .iter() + .map(|(stack, count)| { + let numbered_stack = stack + .iter() + .rev() + .enumerate() + .map(|(index, line)| { + let modified_line = if strip_hashes { strip_hash(line) } else { line.clone() }; + format!("({}) {}", index + 1, modified_line) + }) + .collect::>() + .join("\n"); + (numbered_stack, *count) + }) + .collect(); + + raw_counts.sort_by(|a, b| b.1.cmp(&a.1)); + if let Some(f) = function_name { + println!("\n\n Stack patterns for function '{f}' "); + print_intruction_counts("Function Stack",raw_counts, top_n, strip_hashes,None); } Ok(()) }