Skip to content

Commit

Permalink
refactor(test): switch to env_logger
Browse files Browse the repository at this point in the history
  • Loading branch information
philpax committed Jul 1, 2023
1 parent 038e218 commit 1155495
Showing 5 changed files with 74 additions and 49 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ license = "MIT OR Apache-2.0"
[workspace.dependencies]
bytemuck = "1.13.1"
bytesize = "1.1"
env_logger = "0.10.0"
log = "0.4"
rand = "0.8.5"
thiserror = "1.0"
2 changes: 1 addition & 1 deletion binaries/llm-cli/Cargo.toml
Original file line number Diff line number Diff line change
@@ -15,14 +15,14 @@ path = "src/main.rs"
llm = { path = "../../crates/llm", version = "0.2.0-dev" }

bytesize = { workspace = true }
env_logger = { workspace = true }
log = { workspace = true }
rand = { workspace = true }
rustyline = { workspace = true }
spinoff = { workspace = true }
clap = { workspace = true }

bincode = "1.3.3"
env_logger = "0.10.0"
num_cpus = "1.15.0"

color-eyre = { version = "0.6.2", default-features = false }
2 changes: 2 additions & 0 deletions binaries/llm-test/Cargo.toml
Original file line number Diff line number Diff line change
@@ -14,6 +14,8 @@ llm = { path = "../../crates/llm", version = "0.2.0-dev" }

anyhow = { workspace = true }
clap = { workspace = true }
env_logger = { workspace = true }
log = { workspace = true }
rand = { workspace = true }

reqwest = "0.11.9"
116 changes: 68 additions & 48 deletions binaries/llm-test/src/main.rs
Original file line number Diff line number Diff line change
@@ -34,6 +34,9 @@ struct Cli {

#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Set up the logger
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();

// Parse command line arguments
let args = Cli::parse();
let specific_model = args.architecture.clone();
@@ -50,13 +53,15 @@ async fn main() -> anyhow::Result<()> {

// Load configurations
let mut test_cases = HashMap::new();
for entry in fs::read_dir(configs_dir)? {
let path = entry?.path();
if path.is_file() && path.extension().map_or(false, |ext| ext == "json") {
let file_name = path.file_stem().unwrap().to_string_lossy().to_string();
let test_case: TestCase = serde_json::from_str(&fs::read_to_string(&path)?)?;
test_cases.insert(file_name, test_case);
}
for path in fs::read_dir(configs_dir)?
.filter_map(Result::ok)
.map(|de| de.path())
.filter(|p| p.is_file())
.filter(|p| p.extension().unwrap_or_default() == "json")
{
let file_name = path.file_stem().unwrap().to_string_lossy().to_string();
let test_case: TestCase = serde_json::from_str(&fs::read_to_string(&path)?)?;
test_cases.insert(file_name, test_case);
}
let model_config = ModelConfig {
mmap: !args.no_mmap,
@@ -78,12 +83,16 @@ async fn main() -> anyhow::Result<()> {
test_cases
};

let test_cases_len = test_cases.len();
for (key, test_case) in test_cases {
println!("Key: {key}, Config: {test_case:?}");
log::info!("Key: {key}, Config: {test_case:?}");
test_model(&model_config, &test_case, &download_dir, &results_dir).await?;
if test_cases_len > 1 {
log::info!("----");
}
}

println!("All tests passed!");
log::info!("All tests passed!");
Ok(())
}

@@ -100,7 +109,7 @@ struct TestCase {

#[derive(Serialize)]
pub struct Report {
pub could_loaded: bool,
pub could_load: bool,
pub inference_stats: Option<InferenceStats>,
pub error: Option<String>,
pub output: String,
@@ -112,7 +121,7 @@ async fn test_model(
download_dir: &Path,
results_dir: &Path,
) -> anyhow::Result<()> {
println!("Testing architecture: `{}` ...", test_case.architecture);
log::info!("Testing architecture: `{}`", test_case.architecture);

let local_path = if test_case.filename.is_file() {
// If this filename points towards a valid file, use it
@@ -129,41 +138,52 @@ async fn test_model(

// Load the model
let architecture = llm::ModelArchitecture::from_str(&test_case.architecture)?;
let model_result = llm::load_dynamic(
Some(architecture),
&local_path,
llm::TokenizerSource::Embedded,
llm::ModelParameters {
prefer_mmap: config.mmap,
..Default::default()
},
llm::load_progress_callback_stdout,
);

let model = match model_result {
Ok(m) => m,
Err(err) => {
// Create a report with could_loaded set to false
let report = Report {
could_loaded: false,
inference_stats: None,
error: Some(format!("Failed to load model: {}", err)),
output: String::new(),
};

// Serialize the report to a JSON string
let json_report = serde_json::to_string(&report)?;
let report_path = results_dir.join(format!("{}.json", test_case.architecture));

// Write the JSON report to a file
fs::write(report_path, json_report)?;

// Optionally, you can return early or decide how to proceed
return Err(err.into());
let model = {
let model = llm::load_dynamic(
Some(architecture),
&local_path,
llm::TokenizerSource::Embedded,
llm::ModelParameters {
prefer_mmap: config.mmap,
..Default::default()
},
|progress| {
let print = !matches!(&progress,
llm::LoadProgress::TensorLoaded { current_tensor, tensor_count }
if current_tensor % (tensor_count / 10) != 0
);

if print {
log::info!("{:?}", progress);
}
},
);

match model {
Ok(m) => m,
Err(err) => {
// Create a report with could_load set to false
let report = Report {
could_load: false,
inference_stats: None,
error: Some(format!("Failed to load model: {}", err)),
output: String::new(),
};

// Serialize the report to a JSON string
let json_report = serde_json::to_string(&report)?;
let report_path = results_dir.join(format!("{}.json", test_case.architecture));

// Write the JSON report to a file
fs::write(report_path, json_report)?;

// Optionally, you can return early or decide how to proceed
return Err(err.into());
}
}
};

println!(
log::info!(
"Model fully loaded! Elapsed: {}ms",
start_time.elapsed().as_millis()
);
@@ -175,7 +195,7 @@ async fn test_model(
let mut rng: StdRng = SeedableRng::seed_from_u64(42);
let mut output = String::new();

println!("Running inference...");
log::info!("Running inference...");
let res = session.infer::<Infallible>(
model.as_ref(),
&mut rng,
@@ -198,7 +218,7 @@ async fn test_model(
_ => Ok(llm::InferenceFeedback::Continue),
},
);
println!("Inference done!");
log::info!("Inference done!");

// Process the results
let (inference_results, error) = match res {
@@ -208,7 +228,7 @@ async fn test_model(

// Save the results
let report = Report {
could_loaded: true,
could_load: true,
inference_stats: inference_results,
error: error.map(|e| format!("{:?}", e)),
output,
@@ -226,7 +246,7 @@ async fn test_model(
panic!("Error: {}", err);
}

println!(
log::info!(
"Successfully tested architecture `{}`!",
test_case.architecture
);
@@ -236,7 +256,7 @@ async fn test_model(

async fn download_file(url: &str, local_path: &PathBuf) -> anyhow::Result<()> {
if Path::new(local_path).exists() {
println!("Model already exists at {}", local_path.to_string_lossy());
log::info!("Model already exists at {}", local_path.to_string_lossy());
return Ok(());
}

0 comments on commit 1155495

Please sign in to comment.