Skip to content

Commit

Permalink
Modify risc0 tracer util to print function stack to enable better deb…
Browse files Browse the repository at this point in the history
…ugging (Sovereign-Labs#711)

* stack analysis

* fix issues

* some more changes

* minor changes

---------

Co-authored-by: dubbelosix <dub@006.com>
  • Loading branch information
dubbelosix and dubbelosix authored Aug 28, 2023
1 parent 216fb16 commit 211b56f
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 27 deletions.
19 changes: 19 additions & 0 deletions examples/demo-prover/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Directories and paths
TRACER_DIR = ../../utils/zk-cycle-utils/tracer
ELF_PATH_TRACER = ../../../examples/demo-prover/target/riscv-guest/riscv32im-risc0-zkvm-elf/release/rollup
TRACE_PATH_TRACER = ../../../examples/demo-prover/host/rollup.trace

# This allows you to pass additional flags when you call `make run-tracer`.
# For example: `make run-tracer ADDITIONAL_FLAGS="--some-flag"`
ADDITIONAL_FLAGS ?=

.PHONY: generate-files run-tracer

all: generate-files run-tracer

generate-files:
ROLLUP_TRACE=rollup.trace cargo bench --bench prover_bench --features bench

run-tracer:
@cd $(TRACER_DIR) && \
cargo run --release -- --no-raw-counts --rollup-elf $(ELF_PATH_TRACER) --rollup-trace $(TRACE_PATH_TRACER) $(ADDITIONAL_FLAGS)
13 changes: 9 additions & 4 deletions examples/demo-prover/host/benches/prover_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,17 @@ impl RegexAppender {
impl log::Log for RegexAppender {
fn log(&self, record: &log::Record) {
if let Some(captures) = self.regex.captures(record.args().to_string().as_str()) {
let mut file_guard = self.file.lock().unwrap();
if let Some(matched_pc) = captures.get(1) {
let pc_value_num = u64::from_str_radix(&matched_pc.as_str()[2..], 16).unwrap();
let pc_value = format!("{}\n", pc_value_num);
let mut file_guard = self.file.lock().unwrap();
let pc_value = format!("{}\t", pc_value_num);
file_guard.write_all(pc_value.as_bytes()).unwrap();
}
if let Some(matched_iname) = captures.get(2) {
let iname = matched_iname.as_str().to_uppercase();
let iname_value = format!("{}\n", iname);
file_guard.write_all(iname_value.as_bytes()).unwrap();
}
}
}

Expand All @@ -69,8 +74,8 @@ impl log::Log for RegexAppender {
}

fn get_config(rollup_trace: &str) -> Config {
let regex_pattern = r".*?pc: (0x[0-9a-fA-F]+), insn.*";
// let log_file = "/Users/dubbelosix/sovereign/examples/demo-prover/matched_pattern.log";
// [942786] pc: 0x0008e564, insn: 0xffc67613 => andi x12, x12, -4
let regex_pattern = r".*?pc: (0x[0-9a-fA-F]+), insn: .*?=> ([a-z]*?) ";

let custom_appender = RegexAppender::new(regex_pattern, rollup_trace);

Expand Down
2 changes: 1 addition & 1 deletion examples/demo-prover/methods/guest/src/bin/rollup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use celestia::{BlobWithSender, CelestiaHeader};
use const_rollup_config::{ROLLUP_NAMESPACE_RAW, SEQUENCER_DA_ADDRESS};
use demo_stf::app::create_zk_app_template;
use demo_stf::ArrayWitness;

use risc0_adapter::guest::Risc0Guest;
use risc0_zkvm::guest::env;
use sov_rollup_interface::crypto::NoOpHasher;
Expand Down Expand Up @@ -103,7 +104,6 @@ pub fn main() {
let metrics_syscall_name = unsafe {
risc0_zkvm_platform::syscall::SyscallName::from_bytes_with_nul(cycle_string.as_ptr())
};

risc0_zkvm::guest::env::send_recv_slice::<u8, u8>(metrics_syscall_name, &serialized);
}
}
142 changes: 120 additions & 22 deletions utils/zk-cycle-utils/tracer/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,69 @@ struct Args {
#[arg(short, long)]
/// Strip the hashes from the function name while printing
strip_hashes: bool,

#[arg(short, long)]
/// Function name to target for getting stack counts
function_name: Option<String>,

#[arg(short, long)]
/// Exclude functions matching these patterns from display
/// usage: -e func1 -e func2 -e func3
exclude_view: Vec<String>,
}

fn strip_hash(name_with_hash: &str) -> String {
let re = Regex::new(r"::h[0-9a-fA-F]+$").unwrap();
re.replace(name_with_hash, "").to_string()
}

fn print_intruction_counts(count_vec: Vec<(&String, &usize)>, top_n: usize, strip_hashes: bool) {
fn get_cycle_count(insn: &str) -> Result<usize, &'static str> {
// The opcodes and their cycle counts are taken from
// https://github.com/risc0/risc0/blob/main/risc0/zkvm/src/host/server/opcode.rs
match insn {
"LB" | "LH" | "LW" | "LBU" | "LHU" | "ADDI" | "SLLI" | "SLTI" | "SLTIU" |
"AUIPC" | "SB" | "SH" | "SW" | "ADD" | "SUB" | "SLL" | "SLT" | "SLTU" |
"XOR" | "SRL" | "SRA" | "OR" | "AND" | "MUL" | "MULH" | "MULSU" | "MULU" |
"LUI" | "BEQ" | "BNE" | "BLT" | "BGE" | "BLTU" | "BGEU" | "JALR" | "JAL" |
"ECALL" | "EBREAK" => Ok(1),

// Don't see this in the risc0 code base, but MUL, MULH, MULSU, and MULU all take 1 cycle,
// so going with that for MULHU as well.
"MULHU" => Ok(1),

"XORI" | "ORI" | "ANDI" | "SRLI" | "SRAI" | "DIV" | "DIVU" | "REM" | "REMU" => Ok(2),

_ => Err("Decode error"),
}
}

fn print_intruction_counts(first_header: &str,
count_vec: Vec<(String, usize)>,
top_n: usize,
strip_hashes: bool,
exclude_list: Option<&[String]>) {
let mut table = Table::new();
table.set_format(*format::consts::FORMAT_DEFAULT);
table.set_titles(Row::new(vec![
Cell::new("Function Name"),
Cell::new(first_header),
Cell::new("Instruction Count"),
]));

let wrap_width = 90;
let mut row_count = 0;
for (key, value) in count_vec {
let mut cont = false;
if let Some(ev) = exclude_list {
for e in ev {
if key.contains(e) {
cont = true;
break
}
}
if cont {
continue
}
}
let mut stripped_key = key.clone();
if strip_hashes {
stripped_key = strip_hash(&key);
Expand All @@ -75,7 +120,18 @@ fn print_intruction_counts(count_vec: Vec<(&String, &usize)>, top_n: usize, stri
table.printstd();
}

fn _build_lookups_radare_2(
fn focused_stack_counts(function_stack: &[String],
filtered_stack_counts: &mut HashMap<Vec<String>, usize>,
function_name: &str,
instruction: &str) {
if let Some(index) = function_stack.iter().position(|s| s == function_name) {
let truncated_stack = &function_stack[0..=index];
let count = filtered_stack_counts.entry(truncated_stack.to_vec()).or_insert(0);
*count += get_cycle_count(instruction).unwrap();
}
}

fn _build_radare2_lookups(
start_lookup: &mut HashMap<u64, String>,
end_lookup: &mut HashMap<u64, String>,
func_range_lookup: &mut HashMap<String, (u64, u64)>,
Expand Down Expand Up @@ -130,10 +186,18 @@ fn build_goblin_lookups(
Ok(())
}

fn increment_stack_counts(instruction_counts: &mut HashMap<String, usize>, function_stack: &[String]) {
for function_name in function_stack {
*instruction_counts.entry(function_name.clone()).or_insert(0) += 1;
fn increment_stack_counts(instruction_counts: &mut HashMap<String, usize>,
function_stack: &[String],
filtered_stack_counts: &mut HashMap<Vec<String>, usize>,
function_name: &Option<String>,
instruction: &str) {
for f in function_stack {
*instruction_counts.entry(f.clone()).or_insert(0) += get_cycle_count(instruction).unwrap();
}
if let Some(f) = function_name {
focused_stack_counts(function_stack, filtered_stack_counts, &f, instruction)
}

}

fn main() -> std::io::Result<()> {
Expand All @@ -145,6 +209,8 @@ fn main() -> std::io::Result<()> {
let no_stack_counts = args.no_stack_counts;
let no_raw_counts = args.no_raw_counts;
let strip_hashes = args.strip_hashes;
let function_name = args.function_name;
let exclude_view = args.exclude_view;

let mut start_lookup = HashMap::new();
let mut end_lookup = HashMap::new();
Expand All @@ -153,7 +219,7 @@ fn main() -> std::io::Result<()> {

let mut function_ranges: Vec<(u64, u64, String)> = func_range_lookup
.iter()
.map(|(function_name, &(start, end))| (start, end, function_name.clone()))
.map(|(f, &(start, end))| (start, end, f.clone()))
.collect();

function_ranges.sort_by_key(|&(start, _, _)| start);
Expand All @@ -162,6 +228,7 @@ fn main() -> std::io::Result<()> {
let mut function_stack: Vec<String> = Vec::new();
let mut instruction_counts: HashMap<String, usize> = HashMap::new();
let mut counts_without_callgraph: HashMap<String, usize> = HashMap::new();
let mut filtered_stack_counts: HashMap<Vec<String>, usize> = HashMap::new();
let total_lines = file_content.lines().count() as u64;
let mut current_function_range : (u64,u64) = (0,0);

Expand All @@ -176,7 +243,9 @@ fn main() -> std::io::Result<()> {
if c % &update_interval == 0 {
pb.inc(update_interval as u64);
}
let pc = line.parse().unwrap();
let mut parts = line.split("\t");
let pc = parts.next().unwrap_or_default().parse().unwrap();
let instruction = parts.next().unwrap_or_default();

// Raw counts without considering the callgraph at all
// we're just checking if the PC belongs to a function
Expand All @@ -193,9 +262,9 @@ fn main() -> std::io::Result<()> {
} })
{
let (_, _, fname) = &function_ranges[index];
*counts_without_callgraph.entry(fname.clone()).or_insert(0) += 1;
*counts_without_callgraph.entry(fname.clone()).or_insert(0) += get_cycle_count(instruction).unwrap();
} else {
*counts_without_callgraph.entry("anonymous".to_string()).or_insert(0) += 1;
*counts_without_callgraph.entry("anonymous".to_string()).or_insert(0) += get_cycle_count(instruction).unwrap();
}

// The next section considers the callstack
Expand All @@ -204,17 +273,17 @@ fn main() -> std::io::Result<()> {

// we are still in the current function
if pc > current_function_range.0 && pc <= current_function_range.1 {
increment_stack_counts(&mut instruction_counts, &function_stack);
increment_stack_counts(&mut instruction_counts, &function_stack, &mut filtered_stack_counts, &function_name,instruction);
continue;
}

// jump to a new function (or the same one)
if let Some(function_name) = start_lookup.get(&pc) {
increment_stack_counts(&mut instruction_counts, &function_stack);
if let Some(f) = start_lookup.get(&pc) {
increment_stack_counts(&mut instruction_counts, &function_stack, &mut filtered_stack_counts, &function_name,instruction);
// jump to a new function (not recursive)
if !function_stack.contains(&function_name) {
function_stack.push(function_name.clone());
current_function_range = *func_range_lookup.get(function_name).unwrap();
if !function_stack.contains(&f) {
function_stack.push(f.clone());
current_function_range = *func_range_lookup.get(f).unwrap();
}
} else {
// this means pc now points to an instruction that is
Expand All @@ -237,33 +306,62 @@ fn main() -> std::io::Result<()> {
if unwind_found {

function_stack.truncate(unwind_point + 1);
increment_stack_counts(&mut instruction_counts, &function_stack);
increment_stack_counts(&mut instruction_counts, &function_stack, &mut filtered_stack_counts, &function_name,instruction);
continue;
}

// if no unwind point has been found, that means we jumped to some random location
// so we'll just increment the counts for everything in the stack
increment_stack_counts(&mut instruction_counts, &function_stack);
increment_stack_counts(&mut instruction_counts, &function_stack, &mut filtered_stack_counts, &function_name,instruction);
}

}

pb.finish_with_message("done");

let mut raw_counts: Vec<(&String, &usize)> = instruction_counts.iter().collect();
let mut raw_counts: Vec<(String, usize)> = instruction_counts
.iter()
.map(|(key, value)| (key.clone(), value.clone()))
.collect();
raw_counts.sort_by(|a, b| b.1.cmp(&a.1));

println!("\n\nTotal instructions in trace: {}", total_lines);
if !no_stack_counts {
println!("\n\n Instruction counts considering call graph");
print_intruction_counts(raw_counts, top_n, strip_hashes);
print_intruction_counts("Function Name", raw_counts, top_n, strip_hashes,Some(&exclude_view));
}

let mut raw_counts: Vec<(&String, &usize)> = counts_without_callgraph.iter().collect();
let mut raw_counts: Vec<(String, usize)> = counts_without_callgraph
.iter()
.map(|(key, value)| (key.clone(), value.clone()))
.collect();
raw_counts.sort_by(|a, b| b.1.cmp(&a.1));
if !no_raw_counts {
println!("\n\n Instruction counts ignoring call graph");
print_intruction_counts(raw_counts, top_n, strip_hashes);
print_intruction_counts("Function Name",raw_counts, top_n, strip_hashes,Some(&exclude_view));
}

let mut raw_counts: Vec<(String, usize)> = filtered_stack_counts
.iter()
.map(|(stack, count)| {
let numbered_stack = stack
.iter()
.rev()
.enumerate()
.map(|(index, line)| {
let modified_line = if strip_hashes { strip_hash(line) } else { line.clone() };
format!("({}) {}", index + 1, modified_line)
})
.collect::<Vec<_>>()
.join("\n");
(numbered_stack, *count)
})
.collect();

raw_counts.sort_by(|a, b| b.1.cmp(&a.1));
if let Some(f) = function_name {
println!("\n\n Stack patterns for function '{f}' ");
print_intruction_counts("Function Stack",raw_counts, top_n, strip_hashes,None);
}
Ok(())
}

0 comments on commit 211b56f

Please sign in to comment.