diff --git a/binaries/llm-test/src/main.rs b/binaries/llm-test/src/main.rs
index b38ed4a1..5fa33d15 100644
--- a/binaries/llm-test/src/main.rs
+++ b/binaries/llm-test/src/main.rs
@@ -156,6 +156,9 @@ async fn test_model(
     download_dir: &Path,
     results_dir: &Path,
 ) -> anyhow::Result<()> {
+    // Load the model
+    let architecture = llm::ModelArchitecture::from_str(&test_config.architecture)?;
+
     let local_path = if test_config.filename.is_file() {
         // If this filename points towards a valid file, use it
         test_config.filename.clone()
@@ -173,99 +176,134 @@ async fn test_model(
     // Download the model if necessary
     download_file(&test_config.url, &local_path).await?;
 
-    let start_time = Instant::now();
-
-    // Load the model
-    let architecture = llm::ModelArchitecture::from_str(&test_config.architecture)?;
-    let model = {
-        let model = llm::load_dynamic(
-            Some(architecture),
-            &local_path,
-            llm::TokenizerSource::Embedded,
-            llm::ModelParameters {
-                prefer_mmap: model_config.mmap,
-                ..Default::default()
-            },
-            |progress| {
-                let print = !matches!(&progress,
-                    llm::LoadProgress::TensorLoaded { current_tensor, tensor_count }
-                    if current_tensor % (tensor_count / 10) != 0
-                );
-
-                if print {
-                    log::info!("loading: {:?}", progress);
-                }
-            },
-        );
-
-        match model {
-            Ok(m) => m,
-            Err(err) => {
-                write_report(
-                    test_config,
-                    results_dir,
-                    &Report::LoadFail {
-                        error: format!("Failed to load model: {}", err),
-                    },
-                )?;
-
-                return Err(err.into());
-            }
-        }
-    };
-
-    log::info!(
-        "Model fully loaded! Elapsed: {}ms",
-        start_time.elapsed().as_millis()
-    );
-
-    // Confirm that the model can be sent to a thread, then sent back
-    let model = std::thread::spawn(move || model).join().unwrap();
-
-    // Run the test cases
-    let mut test_case_reports = vec![];
-    for test_case in &test_config.test_cases {
-        match test_case {
-            TestCase::Inference {
-                input,
-                output,
-                maximum_token_count,
-            } => test_case_reports.push(tests::inference(
-                model.as_ref(),
-                model_config,
-                input,
-                output,
-                *maximum_token_count,
-            )?),
-        }
-    }
-    let first_error: Option<String> =
-        test_case_reports
-            .iter()
-            .find_map(|report: &TestCaseReport| match &report.meta {
-                TestCaseReportMeta::Error { error } => Some(error.clone()),
-                _ => None,
-            });
-
-    // Save the results
-    // Serialize the report to a JSON string
-    write_report(
+    struct TestVisitor<'a> {
+        model_config: &'a ModelConfig,
+        test_config: &'a TestConfig,
+        results_dir: &'a Path,
+        local_path: &'a Path,
+    }
+    impl<'a> llm::ModelArchitectureVisitor<anyhow::Result<()>> for TestVisitor<'a> {
+        fn visit<M: llm::KnownModel + 'static>(&mut self) -> anyhow::Result<()> {
+            let Self {
+                model_config,
+                test_config,
+                results_dir,
+                local_path,
+            } = *self;
+
+            let start_time = Instant::now();
+
+            let model = {
+                let model = llm::load::<M>(
+                    local_path,
+                    llm::TokenizerSource::Embedded,
+                    llm::ModelParameters {
+                        prefer_mmap: model_config.mmap,
+                        ..Default::default()
+                    },
+                    |progress| {
+                        let print = !matches!(&progress,
+                            llm::LoadProgress::TensorLoaded { current_tensor, tensor_count }
+                            if current_tensor % (tensor_count / 10) != 0
+                        );
+
+                        if print {
+                            log::info!("loading: {:?}", progress);
+                        }
+                    },
+                );
+
+                match model {
+                    Ok(m) => m,
+                    Err(err) => {
+                        write_report(
+                            test_config,
+                            results_dir,
+                            &Report::LoadFail {
+                                error: format!("Failed to load model: {}", err),
+                            },
+                        )?;
+
+                        return Err(err.into());
+                    }
+                }
+            };
+
+            log::info!(
+                "Model fully loaded! Elapsed: {}ms",
Elapsed: {}ms", + start_time.elapsed().as_millis() + ); + + // + // Non-model-specific tests + // + + // Confirm that the model can be sent to a thread, then sent back + let model = tests::can_send(model)?; + + // Confirm that the hyperparameters can be roundtripped + tests::can_roundtrip_hyperparameters(&model)?; + + // + + // + // Model-specific tests + // + + // Run the test cases + let mut test_case_reports = vec![]; + for test_case in &test_config.test_cases { + match test_case { + TestCase::Inference { + input, + output, + maximum_token_count, + } => test_case_reports.push(tests::can_infer( + &model, + model_config, + input, + output, + *maximum_token_count, + )?), + } + } + let first_error: Option = + test_case_reports + .iter() + .find_map(|report: &TestCaseReport| match &report.meta { + TestCaseReportMeta::Error { error } => Some(error.clone()), + _ => None, + }); + + // Save the results + // Serialize the report to a JSON string + write_report( + test_config, + results_dir, + &Report::LoadSuccess { + test_cases: test_case_reports, + }, + )?; - return Err(err.into()); + // Optionally, panic if there was an error + if let Some(err) = first_error { + panic!("Error: {}", err); } - } - }; - log::info!( - "Model fully loaded! Elapsed: {}ms", - start_time.elapsed().as_millis() - ); + log::info!( + "Successfully tested architecture `{}`!", + test_config.architecture + ); - // Confirm that the model can be sent to a thread, then sent back - let model = std::thread::spawn(move || model).join().unwrap(); - - // Run the test cases - let mut test_case_reports = vec![]; - for test_case in &test_config.test_cases { - match test_case { - TestCase::Inference { - input, - output, - maximum_token_count, - } => test_case_reports.push(tests::inference( - model.as_ref(), - model_config, - input, - output, - *maximum_token_count, - )?), + Ok(()) } } - let first_error: Option = - test_case_reports - .iter() - .find_map(|report: &TestCaseReport| match &report.meta { - TestCaseReportMeta::Error { error } => Some(error.clone()), - _ => None, - }); - - // Save the results - // Serialize the report to a JSON string - write_report( + architecture.visit(&mut TestVisitor { + model_config, test_config, results_dir, - &Report::LoadSuccess { - test_cases: test_case_reports, - }, - )?; - - // Optionally, panic if there was an error - if let Some(err) = first_error { - panic!("Error: {}", err); - } - - log::info!( - "Successfully tested architecture `{}`!", - test_config.architecture - ); + local_path: &local_path, + })?; Ok(()) } @@ -283,7 +321,33 @@ fn write_report( mod tests { use super::*; - pub(super) fn inference( + + pub(super) fn can_send(model: M) -> anyhow::Result { + std::thread::spawn(move || model) + .join() + .map_err(|e| anyhow::anyhow!("Failed to join thread: {e:?}")) + } + + pub(super) fn can_roundtrip_hyperparameters( + model: &M, + ) -> anyhow::Result<()> { + fn test_hyperparameters( + hyperparameters: &M, + ) -> anyhow::Result<()> { + let mut data = vec![]; + hyperparameters.write_ggml(&mut data)?; + let new_hyperparameters = + ::read_ggml(&mut std::io::Cursor::new(data))?; + + assert_eq!(hyperparameters, &new_hyperparameters); + + Ok(()) + } + + test_hyperparameters(model.hyperparameters()) + } + + pub(super) fn can_infer( model: &dyn llm::Model, model_config: &ModelConfig, input: &str, diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs index 45eb8650..2de02828 100644 --- a/crates/llm-base/src/model/mod.rs +++ b/crates/llm-base/src/model/mod.rs @@ -65,6 +65,9 @@ pub 
diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs
index 45eb8650..2de02828 100644
--- a/crates/llm-base/src/model/mod.rs
+++ b/crates/llm-base/src/model/mod.rs
@@ -65,6 +65,9 @@ pub trait KnownModel: Send + Sync {
         output_request: &mut OutputRequest,
     );
 
+    /// Get the hyperparameters for this model.
+    fn hyperparameters(&self) -> &Self::Hyperparameters;
+
     /// Get the tokenizer for this model.
     fn tokenizer(&self) -> &Tokenizer;
 
@@ -150,7 +153,7 @@ impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M {
 
 /// Implemented by model hyperparameters for interacting with hyperparameters
 /// without knowing what they are, as well as writing/reading them as required.
-pub trait Hyperparameters: Sized + Default + Debug {
+pub trait Hyperparameters: Sized + Default + Debug + PartialEq + Eq {
     /// Read the parameters in GGML format from a reader.
     fn read_ggml(reader: &mut dyn BufRead) -> Result<Self, LoadError>;
 
diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs
index 194c9d28..291503a5 100644
--- a/crates/llm/src/lib.rs
+++ b/crates/llm/src/lib.rs
@@ -78,8 +78,8 @@ use std::{
 // This is the "user-facing" API, and GGML may not always be our backend.
 pub use llm_base::{
     feed_prompt_callback, ggml::format as ggml_format, load, load_progress_callback_stdout,
-    quantize, samplers, ElementType, FileType, FileTypeFormat, InferenceError, InferenceFeedback,
-    InferenceParameters, InferenceRequest, InferenceResponse, InferenceSession,
+    quantize, samplers, ElementType, FileType, FileTypeFormat, Hyperparameters, InferenceError,
+    InferenceFeedback, InferenceParameters, InferenceRequest, InferenceResponse, InferenceSession,
     InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, InferenceStats,
     InvalidTokenBias, KnownModel, LoadError, LoadProgress, Loader, Model, ModelKVMemoryType,
     ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, Sampler,
diff --git a/crates/models/bloom/src/lib.rs b/crates/models/bloom/src/lib.rs
index 18cd5e5b..4e9aa192 100644
--- a/crates/models/bloom/src/lib.rs
+++ b/crates/models/bloom/src/lib.rs
@@ -369,6 +369,10 @@ impl KnownModel for Bloom {
         common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len);
     }
 
+    fn hyperparameters(&self) -> &Self::Hyperparameters {
+        &self.hyperparameters
+    }
+
     fn tokenizer(&self) -> &Tokenizer {
         &self.tokenizer
     }
diff --git a/crates/models/falcon/src/lib.rs b/crates/models/falcon/src/lib.rs
index 8ee37453..2ac269ef 100644
--- a/crates/models/falcon/src/lib.rs
+++ b/crates/models/falcon/src/lib.rs
@@ -328,6 +328,10 @@ impl KnownModel for Falcon {
         common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len);
     }
 
+    fn hyperparameters(&self) -> &Self::Hyperparameters {
+        &self.hyperparameters
+    }
+
     fn tokenizer(&self) -> &Tokenizer {
         &self.tokenizer
     }
@@ -354,7 +358,7 @@ impl KnownModel for Falcon {
 }
 
 /// Falcon [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
-#[derive(Debug, Default, PartialEq, Clone, Copy)]
+#[derive(Debug, Default, PartialEq, Clone, Copy, Eq)]
 pub struct Hyperparameters {
     /// Size of the model's vocabulary
     n_vocab: usize,
diff --git a/crates/models/gpt2/src/lib.rs b/crates/models/gpt2/src/lib.rs
index 1b2427a5..646d3a98 100644
--- a/crates/models/gpt2/src/lib.rs
+++ b/crates/models/gpt2/src/lib.rs
@@ -323,6 +323,10 @@ impl KnownModel for Gpt2 {
         common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len);
     }
 
+    fn hyperparameters(&self) -> &Self::Hyperparameters {
+        &self.hyperparameters
+    }
+
     fn tokenizer(&self) -> &Tokenizer {
         &self.tokenizer
     }
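The `PartialEq + Eq` bounds added to `Hyperparameters` exist to support the new roundtrip test above: write the hyperparameters out in GGML format, read them back, and assert the two values are equal. A toy sketch of that shape follows; `Params` and its `write`/`read` methods are hypothetical stand-ins for a real model's hyperparameters and the trait's `write_ggml`/`read_ggml`.

```rust
use std::io::{BufRead, Cursor, Read, Write};

// Toy stand-in for a model's Hyperparameters struct.
#[derive(Debug, Default, PartialEq, Eq)]
struct Params {
    n_vocab: u32,
    n_embd: u32,
}

impl Params {
    // Serialize the fields in a fixed binary layout.
    fn write(&self, w: &mut dyn Write) -> std::io::Result<()> {
        w.write_all(&self.n_vocab.to_le_bytes())?;
        w.write_all(&self.n_embd.to_le_bytes())
    }

    // Deserialize the same layout back into a value.
    fn read(r: &mut dyn BufRead) -> std::io::Result<Self> {
        let mut buf = [0u8; 4];
        r.read_exact(&mut buf)?;
        let n_vocab = u32::from_le_bytes(buf);
        r.read_exact(&mut buf)?;
        let n_embd = u32::from_le_bytes(buf);
        Ok(Self { n_vocab, n_embd })
    }
}

fn main() -> std::io::Result<()> {
    let params = Params {
        n_vocab: 32000,
        n_embd: 4096,
    };

    // Write to an in-memory buffer, then read back from a Cursor,
    // mirroring what can_roundtrip_hyperparameters does.
    let mut data = vec![];
    params.write(&mut data)?;
    let roundtripped = Params::read(&mut Cursor::new(data))?;

    // This equality check is exactly why the trait now requires Eq.
    assert_eq!(params, roundtripped);
    Ok(())
}
```

Any field the writer forgets to serialize (or the reader parses in the wrong order) makes the final assertion fail, which is what gives the test its value.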
diff --git a/crates/models/gptj/src/lib.rs b/crates/models/gptj/src/lib.rs
index 92ee8f4a..195f876a 100644
--- a/crates/models/gptj/src/lib.rs
+++ b/crates/models/gptj/src/lib.rs
@@ -291,6 +291,10 @@ impl KnownModel for GptJ {
         common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len);
     }
 
+    fn hyperparameters(&self) -> &Self::Hyperparameters {
+        &self.hyperparameters
+    }
+
     fn tokenizer(&self) -> &Tokenizer {
         &self.tokenizer
     }
diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs
index 84a5c417..5339b901 100644
--- a/crates/models/gptneox/src/lib.rs
+++ b/crates/models/gptneox/src/lib.rs
@@ -337,6 +337,10 @@ impl KnownModel for GptNeoX {
         common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, n);
     }
 
+    fn hyperparameters(&self) -> &Self::Hyperparameters {
+        &self.hyperparameters
+    }
+
     fn tokenizer(&self) -> &Tokenizer {
         &self.tokenizer
     }
diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs
index 025352fd..b6375647 100644
--- a/crates/models/llama/src/lib.rs
+++ b/crates/models/llama/src/lib.rs
@@ -321,6 +321,10 @@ impl KnownModel for Llama {
         common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len);
     }
 
+    fn hyperparameters(&self) -> &Self::Hyperparameters {
+        &self.hyperparameters
+    }
+
     fn tokenizer(&self) -> &Tokenizer {
         &self.tokenizer
     }
diff --git a/crates/models/mpt/src/lib.rs b/crates/models/mpt/src/lib.rs
index 56e129e4..18991adf 100644
--- a/crates/models/mpt/src/lib.rs
+++ b/crates/models/mpt/src/lib.rs
@@ -271,6 +271,10 @@ impl KnownModel for Mpt {
         common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, n);
     }
 
+    fn hyperparameters(&self) -> &Self::Hyperparameters {
+        &self.hyperparameters
+    }
+
     fn tokenizer(&self) -> &Tokenizer {
         &self.tokenizer
     }
@@ -316,6 +320,7 @@ pub struct Hyperparameters {
     /// file_type
     file_type: FileType,
 }
+impl Eq for Hyperparameters {}
 
 impl llm_base::Hyperparameters for Hyperparameters {
     fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
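Note the asymmetry in how the models gain `Eq`: Falcon simply adds it to its `derive` list, while MPT needs a manual `impl Eq for Hyperparameters {}`. That is presumably because MPT's struct contains `f32` fields (not all visible in this hunk), and `f32` is only `PartialEq` (`NaN != NaN` breaks reflexivity), so `Eq` cannot be derived. A toy illustration, with a hypothetical field name:

```rust
// A struct with an f32 field cannot #[derive(Eq)], because f32 is
// only PartialEq; deriving would fail to compile.
#[derive(Debug, Default, PartialEq)]
struct Hyper {
    alibi_bias_max: f32, // hypothetical float field blocking the derive
}

// Manually asserting Eq is a promise that the value is never NaN,
// so equality behaves as a true equivalence relation.
impl Eq for Hyper {}

fn main() {
    assert_eq!(Hyper::default(), Hyper::default());
}
```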