From 9a222690cd0a9e3322bb6a926d54e90d08fb2c0f Mon Sep 17 00:00:00 2001
From: Philpax
Date: Fri, 30 Jun 2023 03:17:14 +0200
Subject: [PATCH] fix #317 - cli move architecture into subcommands

---
 binaries/llm-cli/src/cli_args.rs   |  69 +++------
 binaries/llm-cli/src/main.rs       | 220 +++++++++++++++--------------
 crates/llm-base/src/loader.rs      |  12 +-
 crates/llm-base/src/model/mod.rs   |   3 +-
 crates/llm/examples/embeddings.rs  |   2 +-
 crates/llm/examples/inference.rs   |   2 +-
 crates/llm/examples/vicuna-chat.rs |   2 +-
 crates/llm/src/lib.rs              |  91 +++++++-----
 8 files changed, 208 insertions(+), 193 deletions(-)

diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs
index ce7db33f..0da6e3c5 100644
--- a/binaries/llm-cli/src/cli_args.rs
+++ b/binaries/llm-cli/src/cli_args.rs
@@ -1,6 +1,6 @@
 use std::{fmt, ops::Deref, path::PathBuf, sync::Arc};
 
-use clap::{Parser, Subcommand, ValueEnum};
+use clap::{Parser, ValueEnum};
 use color_eyre::eyre::{bail, Result, WrapErr};
 use llm::{
     ggml_format, ElementType, InferenceParameters, InferenceSessionConfig, InvalidTokenBias,
@@ -11,50 +11,6 @@ use rand::SeedableRng;
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 pub enum Args {
-    /// Use a BLOOM model
-    Bloom {
-        #[command(subcommand)]
-        args: BaseArgs,
-    },
-    /// Use a GPT-2 model
-    Gpt2 {
-        #[command(subcommand)]
-        args: BaseArgs,
-    },
-    /// Use a GPT-J model
-    #[clap(id = "gptj")]
-    GptJ {
-        #[command(subcommand)]
-        args: BaseArgs,
-    },
-    /// Use a GPT-NeoX model
-    #[clap(id = "gptneox")]
-    GptNeoX {
-        #[command(subcommand)]
-        args: BaseArgs,
-    },
-    /// Use a LLaMA model
-    Llama {
-        #[command(subcommand)]
-        args: BaseArgs,
-    },
-    /// Use a MPT model
-    #[clap(id = "mpt")]
-    Mpt {
-        #[command(subcommand)]
-        args: BaseArgs,
-    },
-    /// Use a Falcon model
-    #[clap(id = "falcon")]
-    #[cfg(feature = "falcon")]
-    Falcon {
-        #[command(subcommand)]
-        args: BaseArgs,
-    },
-}
-
-#[derive(Subcommand, Debug)]
-pub enum BaseArgs {
     #[command()]
     /// Use a model to infer the next tokens in a sequence, and exit.
     Infer(Box<Infer>),
@@ -156,7 +112,7 @@ pub struct Info {
     pub tensors: bool,
 
     /// Show all of the tokens in the tokenizer.
-    #[arg(long, short = 'v')]
+    #[arg(long, short = 'k')]
     pub tokenizer: bool,
 }
 
@@ -372,12 +328,22 @@ impl ModelTokenizer {
     }
 }
 
+#[derive(Parser, Debug)]
+pub struct ModelArchitecture {
+    /// The model architecture to use. Will attempt to guess if not specified.
+    #[arg(long, short = 'a')]
+    pub model_architecture: Option<llm::ModelArchitecture>,
+}
+
 #[derive(Parser, Debug)]
 pub struct ModelAndTokenizer {
     /// Where to load the model from
     #[arg(long, short = 'm')]
     pub model_path: PathBuf,
 
+    #[command(flatten)]
+    pub architecture: ModelArchitecture,
+
     #[command(flatten)]
     pub tokenizer: ModelTokenizer,
 }
@@ -415,7 +381,7 @@ pub struct ModelLoad {
     pub lora_paths: Option<Vec<PathBuf>>,
 }
 impl ModelLoad {
-    pub fn load<M: llm::KnownModel + 'static>(&self, use_gpu: bool) -> Result<Box<M>> {
+    pub fn load(&self, use_gpu: bool) -> Result<Box<dyn llm::Model>> {
         let params = ModelParameters {
             prefer_mmap: !self.no_mmap,
             context_size: self.num_ctx_tokens,
@@ -441,7 +407,8 @@ impl ModelLoad {
             }
         };
 
-        let model = llm::load::<M>(
+        let model = llm::load_dynamic(
+            self.model_and_tokenizer.architecture.model_architecture,
             &self.model_and_tokenizer.model_path,
             tokenizer_source,
             params,
@@ -496,7 +463,6 @@ impl ModelLoad {
                 }
             },
         )
-        .map(Box::new)
        .wrap_err("Could not load model");
 
         if model.is_err() {
@@ -507,7 +473,7 @@ impl ModelLoad {
             }
         }
 
-        Ok(model?)
+        model
     }
 }
 
@@ -548,6 +514,9 @@ impl PromptFile {
 
 #[derive(Parser, Debug)]
 pub struct Quantize {
+    #[command(flatten)]
+    pub architecture: ModelArchitecture,
+
     /// The path to the model to quantize
     #[arg()]
     pub source: PathBuf,
diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs
index 45e4c127..443f6733 100644
--- a/binaries/llm-cli/src/main.rs
+++ b/binaries/llm-cli/src/main.rs
@@ -5,8 +5,8 @@ use std::{
 };
 
 use clap::Parser;
-use cli_args::{Args, BaseArgs};
-use color_eyre::eyre::{Context, Result};
+use cli_args::Args;
+use color_eyre::eyre::{Context, ContextCompat, Result};
 use llm::{InferenceError, InferenceFeedback, InferenceResponse};
 use rustyline::{
     error::ReadlineError,
@@ -25,35 +25,22 @@ fn main() -> Result<()> {
         .init();
 
     color_eyre::install()?;
-    let cli_args = Args::parse();
-    match &cli_args {
-        Args::Llama { args } => handle_args::<llm::models::Llama>(args),
-        Args::Bloom { args } => handle_args::<llm::models::Bloom>(args),
-        Args::Gpt2 { args } => handle_args::<llm::models::Gpt2>(args),
-        Args::GptJ { args } => handle_args::<llm::models::GptJ>(args),
-        Args::GptNeoX { args } => handle_args::<llm::models::GptNeoX>(args),
-        Args::Mpt { args } => handle_args::<llm::models::Mpt>(args),
-        #[cfg(feature = "falcon")]
-        Args::Falcon { args } => handle_args::<llm::models::Falcon>(args),
-    }
-}
-
-fn handle_args<M: llm::KnownModel + 'static>(args: &cli_args::BaseArgs) -> Result<()> {
+    let args = Args::parse();
     match args {
-        BaseArgs::Infer(args) => infer::<M>(args),
-        BaseArgs::Perplexity(args) => perplexity::<M>(args),
-        BaseArgs::Info(args) => info::<M>(args),
-        BaseArgs::PromptTokens(args) => prompt_tokens::<M>(args),
-        BaseArgs::Repl(args) => interactive::<M>(args, false),
-        BaseArgs::Chat(args) => interactive::<M>(args, true),
-        BaseArgs::Quantize(args) => quantize::<M>(args),
+        Args::Infer(args) => infer(&args),
+        Args::Perplexity(args) => perplexity(&args),
+        Args::Info(args) => info(&args),
+        Args::PromptTokens(args) => prompt_tokens(&args),
+        Args::Repl(args) => interactive(&args, false),
+        Args::Chat(args) => interactive(&args, true),
+        Args::Quantize(args) => quantize(&args),
     }
 }
 
-fn infer<M: llm::KnownModel + 'static>(args: &cli_args::Infer) -> Result<()> {
+fn infer(args: &cli_args::Infer) -> Result<()> {
     let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
     let inference_session_config = args.generate.inference_session_config();
-    let model = args.model_load.load::<M>(args.generate.use_gpu)?;
+    let model = args.model_load.load(args.generate.use_gpu)?;
 
     let (mut session, session_loaded) = snapshot::read_or_create_session(
         model.as_ref(),
@@ -118,10 +105,10 @@ fn infer<M: llm::KnownModel + 'static>(args: &cli_args::Infer) -> Result<()> {
     Ok(())
 }
 
-fn perplexity<M: llm::KnownModel + 'static>(args: &cli_args::Perplexity) -> Result<()> {
+fn perplexity(args: &cli_args::Perplexity) -> Result<()> {
     let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
     let inference_session_config = args.generate.inference_session_config();
-    let model = args.model_load.load::<M>(args.generate.use_gpu)?;
+    let model = args.model_load.load(args.generate.use_gpu)?;
     let (mut session, _) = snapshot::read_or_create_session(
         model.as_ref(),
         None,
@@ -142,48 +129,62 @@ fn perplexity<M: llm::KnownModel + 'static>(args: &cli_args::Perplexity) -> Resu
     Ok(())
 }
 
-fn info<M: llm::KnownModel + 'static>(args: &cli_args::Info) -> Result<()> {
-    let model_path = &args.model_and_tokenizer.model_path;
-    let tokenizer = args.model_and_tokenizer.to_source()?.retrieve(model_path)?;
+fn info(args: &cli_args::Info) -> Result<()> {
+    struct InfoVisitor<'a>(&'a cli_args::Info);
+    impl llm::ModelArchitectureVisitor<Result<()>> for InfoVisitor<'_> {
+        fn visit<M: llm::KnownModel + 'static>(&mut self) -> Result<()> {
+            let args = self.0;
 
-    let file = File::open(model_path)?;
-    let mut reader = BufReader::new(&file);
-    let mut loader: llm::Loader<M::Hyperparameters, _> = llm::Loader::new(tokenizer, |_| {
-        // We purposely do not print progress here, as we are only interested in the metadata
-    });
+            let model_path = &args.model_and_tokenizer.model_path;
+            let tokenizer = args.model_and_tokenizer.to_source()?.retrieve(model_path)?;
 
-    llm::ggml_format::load(&mut reader, &mut loader)?;
+            let file = File::open(model_path)?;
+            let mut reader = BufReader::new(&file);
+            let mut loader: llm::Loader<M::Hyperparameters, _> =
+                llm::Loader::new(tokenizer, |_| {
+                    // We purposely do not print progress here, as we are only interested in the metadata
+                });
 
-    log::info!("Container type: {:?}", loader.container_type);
-    log::info!("Hyperparameters: {:?}", loader.hyperparameters);
-    log::info!("Tokenizer vocabulary size: {}", loader.tokenizer.len());
+            llm::ggml_format::load(&mut reader, &mut loader)?;
 
-    if args.tokenizer {
-        log::info!("Tokens:");
-        for i in 0..loader.tokenizer.len() {
-            log::info!("- {}: {}", i, utf8_or_array(&loader.tokenizer.token(i)));
-        }
-    }
+            log::info!("Container type: {:?}", loader.container_type);
+            log::info!("Hyperparameters: {:?}", loader.hyperparameters);
+            log::info!("Tokenizer vocabulary size: {}", loader.tokenizer.len());
 
-    if args.tensors {
-        log::info!("Tensors:");
-        for (name, tensor) in &loader.tensors {
-            log::info!("- {} ({:?} {:?})", name, tensor.element_type, tensor.dims());
-        }
-    }
+            if args.tokenizer {
+                log::info!("Tokens:");
+                for i in 0..loader.tokenizer.len() {
+                    log::info!("- {}: {}", i, utf8_or_array(&loader.tokenizer.token(i)));
+                }
+            }
 
-    fn utf8_or_array(token: &[u8]) -> String {
-        std::str::from_utf8(token)
-            .map(|s| s.to_owned())
-            .unwrap_or(format!("{:?}", token))
+            if args.tensors {
+                log::info!("Tensors:");
+                for (name, tensor) in &loader.tensors {
+                    log::info!("- {} ({:?} {:?})", name, tensor.element_type, tensor.dims());
+                }
+            }
+
+            fn utf8_or_array(token: &[u8]) -> String {
+                std::str::from_utf8(token)
+                    .map(|s| s.to_owned())
+                    .unwrap_or(format!("{:?}", token))
+            }
+
+            Ok(())
+        }
     }
 
-    Ok(())
+    args.model_and_tokenizer
+        .architecture
+        .model_architecture
+        .wrap_err("a model architecture is required at present")?
+        .visit(&mut InfoVisitor(args))
 }
 
-fn prompt_tokens<M: llm::KnownModel + 'static>(args: &cli_args::PromptTokens) -> Result<()> {
+fn prompt_tokens(args: &cli_args::PromptTokens) -> Result<()> {
     let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
-    let model = args.model_load.load::<M>(false)?;
+    let model = args.model_load.load(false)?;
     let toks = match model.tokenizer().tokenize(&prompt, false) {
         Ok(toks) => toks,
         Err(e) => {
@@ -222,7 +223,7 @@ fn force_newline_event_seq() -> KeyEvent {
     KeyEvent(KeyCode::Enter, Modifiers::SHIFT)
 }
 
-fn interactive<M: llm::KnownModel + 'static>(
+fn interactive(
     args: &cli_args::Repl,
     // If set to false, the session will be cloned after each inference
     // to ensure that previous state is not carried over.
@@ -230,7 +231,7 @@ fn interactive<M: llm::KnownModel + 'static>(
 ) -> Result<()> {
     let prompt_file = args.prompt_file.contents();
     let inference_session_config = args.generate.inference_session_config();
-    let model = args.model_load.load::<M>(args.generate.use_gpu)?;
+    let model = args.model_load.load(args.generate.use_gpu)?;
     let (mut session, mut session_loaded) = snapshot::read_or_create_session(
         model.as_ref(),
         None,
@@ -318,51 +319,64 @@ fn interactive<M: llm::KnownModel + 'static>(
     Ok(())
 }
 
-fn quantize<M: llm::KnownModel + 'static>(args: &cli_args::Quantize) -> Result<()> {
+fn quantize(args: &cli_args::Quantize) -> Result<()> {
     use llm::QuantizeProgress;
 
-    let mut source = BufReader::new(std::fs::File::open(&args.source)?);
-    let mut destination = BufWriter::new(std::fs::File::create(&args.destination)?);
-    let tokenizer = args.tokenizer.to_source()?.retrieve(&args.source)?;
-
-    llm::quantize::<M, _, _>(
-        &mut source,
-        &mut destination,
-        tokenizer,
-        args.container_type.into(),
-        args.target.into(),
-        |progress| match progress {
-            QuantizeProgress::HyperparametersLoaded => log::info!("Loaded hyperparameters"),
-            QuantizeProgress::TensorLoading {
-                name,
-                dims,
-                element_type,
-                n_elements,
-            } => log::info!(
-                "Loading tensor `{name}` ({n_elements} ({dims:?}) {element_type} elements)"
-            ),
-            QuantizeProgress::TensorQuantizing { name } => log::info!("Quantizing tensor `{name}`"),
-            QuantizeProgress::TensorQuantized {
-                name,
-                original_size,
-                reduced_size,
-                history,
-            } => log::info!(
-                "Quantized tensor `{name}` from {original_size} to {reduced_size} bytes ({history:?})"
-            ),
-            QuantizeProgress::TensorSkipped { name, size } => {
-                log::info!("Skipped tensor `{name}` ({size} bytes)")
-            }
-            QuantizeProgress::Finished {
-                original_size,
-                reduced_size,
-                history,
-            } => log::info!(
-                "Finished quantization from {original_size} to {reduced_size} bytes ({history:?})"
-            ),
-        },
-    )
-    .wrap_err("failed to quantize model")
+    struct QuantizeVisitor<'a>(&'a cli_args::Quantize);
+    impl llm::ModelArchitectureVisitor<Result<()>> for QuantizeVisitor<'_> {
+        fn visit<M: llm::KnownModel + 'static>(&mut self) -> Result<()> {
+            let args = self.0;
+
+            let mut source: BufReader<File> = BufReader::new(std::fs::File::open(&args.source)?);
+            let mut destination: BufWriter<File> =
+                BufWriter::new(std::fs::File::create(&args.destination)?);
+            let tokenizer: llm::Tokenizer = args.tokenizer.to_source()?.retrieve(&args.source)?;
+
+            llm::quantize::<M, _, _>(
+                &mut source,
+                &mut destination,
+                tokenizer,
+                args.container_type.into(),
+                args.target.into(),
+                |progress| match progress {
+                    QuantizeProgress::HyperparametersLoaded => log::info!("Loaded hyperparameters"),
+                    QuantizeProgress::TensorLoading {
+                        name,
+                        dims,
+                        element_type,
+                        n_elements,
+                    } => log::info!(
+                        "Loading tensor `{name}` ({n_elements} ({dims:?}) {element_type} elements)"
+                    ),
+                    QuantizeProgress::TensorQuantizing { name } => log::info!("Quantizing tensor `{name}`"),
+                    QuantizeProgress::TensorQuantized {
+                        name,
+                        original_size,
+                        reduced_size,
+                        history,
+                    } => log::info!(
+                        "Quantized tensor `{name}` from {original_size} to {reduced_size} bytes ({history:?})"
+                    ),
+                    QuantizeProgress::TensorSkipped { name, size } => {
+                        log::info!("Skipped tensor `{name}` ({size} bytes)")
+                    }
+                    QuantizeProgress::Finished {
+                        original_size,
+                        reduced_size,
+                        history,
+                    } => log::info!(
+                        "Finished quantization from {original_size} to {reduced_size} bytes ({history:?})"
+                    ),
+                },
+            )
+            .wrap_err("failed to quantize model")
+        }
+    }
+
+    args.architecture
+        .model_architecture
+        .wrap_err("the architecture must be known for quantization")?
+        .visit(&mut QuantizeVisitor(args))
 }
 
 fn load_prompt_file_with_prompt(
diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs
index 13b04516..1725535b 100644
--- a/crates/llm-base/src/loader.rs
+++ b/crates/llm-base/src/loader.rs
@@ -327,13 +327,21 @@ pub enum LoadError {
     },
     /// The tokenizer could not be loaded.
     #[error("could not load tokenizer {path:?}: {error}")]
-    TokenizerLoadError {
+    TokenizerLoadFail {
         /// The invalid tokenizer path
         path: PathBuf,
 
         /// The error that occurred.
         error: Box<dyn Error + Send + Sync>,
     },
+    /// There is insufficient information to guess the model architecture from the provided file.
+    ///
+    /// A model architecture must be provided to load the model.
+    #[error("could not guess model architecture from {path:?}")]
+    MissingModelArchitecture {
+        /// The path that failed.
+        path: PathBuf,
+    },
 }
 impl From<util::FindAllModelFilesError> for LoadError {
     fn from(value: util::FindAllModelFilesError) -> Self {
@@ -345,7 +353,7 @@ impl From<util::FindAllModelFilesError> for LoadError {
 }
 impl From<TokenizerLoadError> for LoadError {
     fn from(value: TokenizerLoadError) -> Self {
-        LoadError::TokenizerLoadError {
+        LoadError::TokenizerLoadFail {
             path: value.path,
             error: value.error,
         }
diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs
index bee50f37..45eb8650 100644
--- a/crates/llm-base/src/model/mod.rs
+++ b/crates/llm-base/src/model/mod.rs
@@ -177,7 +177,8 @@ pub enum HyperparametersWriteError {
     InvalidIntegerConversion(#[from] std::num::TryFromIntError),
 }
 
-/// Parameters for tuning model instances
+/// Parameters for model-wide behaviour.
+#[derive(Debug, Clone)]
 pub struct ModelParameters {
     /// For [GGML formats](ggml::ContainerType) that support it, [mmap](https://en.wikipedia.org/wiki/Mmap)
     /// is the default. Although mmap typically improves performance, setting this value to `false` may
diff --git a/crates/llm/examples/embeddings.rs b/crates/llm/examples/embeddings.rs
index 74207a1d..0a6a999a 100644
--- a/crates/llm/examples/embeddings.rs
+++ b/crates/llm/examples/embeddings.rs
@@ -51,7 +51,7 @@ fn main() {
     // Load model
     let model_params = llm::ModelParameters::default();
     let model = llm::load_dynamic(
-        model_architecture,
+        Some(model_architecture),
         &model_path,
         tokenizer_source,
         model_params,
diff --git a/crates/llm/examples/inference.rs b/crates/llm/examples/inference.rs
index aa740b02..51e7369a 100644
--- a/crates/llm/examples/inference.rs
+++ b/crates/llm/examples/inference.rs
@@ -39,7 +39,7 @@ fn main() {
     let now = std::time::Instant::now();
 
     let model = llm::load_dynamic(
-        model_architecture,
+        Some(model_architecture),
         &model_path,
         tokenizer_source,
         Default::default(),
diff --git a/crates/llm/examples/vicuna-chat.rs b/crates/llm/examples/vicuna-chat.rs
index e08f0be3..7cdeb1d1 100644
--- a/crates/llm/examples/vicuna-chat.rs
+++ b/crates/llm/examples/vicuna-chat.rs
@@ -31,7 +31,7 @@ fn main() {
     let model_architecture = args.model_architecture;
     let model_path = args.model_path;
     let model = llm::load_dynamic(
-        model_architecture,
+        Some(model_architecture),
         &model_path,
         tokenizer_source,
         Default::default(),
diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs
index 30ea6c56..c165deb5 100644
--- a/crates/llm/src/lib.rs
+++ b/crates/llm/src/lib.rs
@@ -153,6 +153,33 @@ impl ModelArchitecture {
     ];
 }
 
+/// Used to dispatch some code based on the model architecture.
+pub trait ModelArchitectureVisitor<R> {
+    /// Visit a model architecture.
+    fn visit<M: KnownModel + 'static>(&mut self) -> R;
+}
+impl ModelArchitecture {
+    /// Use a visitor to dispatch some code based on the model architecture.
+ pub fn visit(&self, visitor: &mut impl ModelArchitectureVisitor) -> R { + match self { + #[cfg(feature = "bloom")] + Self::Bloom => visitor.visit::(), + #[cfg(feature = "gpt2")] + Self::Gpt2 => visitor.visit::(), + #[cfg(feature = "gptj")] + Self::GptJ => visitor.visit::(), + #[cfg(feature = "gptneox")] + Self::GptNeoX => visitor.visit::(), + #[cfg(feature = "llama")] + Self::Llama => visitor.visit::(), + #[cfg(feature = "mpt")] + Self::Mpt => visitor.visit::(), + #[cfg(feature = "falcon")] + Self::Falcon => visitor.visit::(), + } + } +} + /// An unsupported model architecture was specified. pub struct UnsupportedModelArchitecture(String); impl Display for UnsupportedModelArchitecture { @@ -227,18 +254,17 @@ impl Display for ModelArchitecture { } /// A helper function that loads the specified model from disk using an architecture -/// specified at runtime. +/// specified at runtime. If no architecture is specified, it will try to infer it +/// from the model's metadata. /// /// A wrapper around [load] that dispatches to the correct model. pub fn load_dynamic( - architecture: ModelArchitecture, + architecture: Option, path: &Path, tokenizer_source: TokenizerSource, params: ModelParameters, load_progress_callback: impl FnMut(LoadProgress), ) -> Result, LoadError> { - use ModelArchitecture as MA; - fn load_model( path: &Path, tokenizer_source: TokenizerSource, @@ -253,38 +279,35 @@ pub fn load_dynamic( )?)) } - let model: Box = match architecture { - #[cfg(feature = "bloom")] - MA::Bloom => { - load_model::(path, tokenizer_source, params, load_progress_callback)? - } - #[cfg(feature = "gpt2")] - MA::Gpt2 => { - load_model::(path, tokenizer_source, params, load_progress_callback)? - } - #[cfg(feature = "gptj")] - MA::GptJ => { - load_model::(path, tokenizer_source, params, load_progress_callback)? - } - #[cfg(feature = "gptneox")] - MA::GptNeoX => { - load_model::(path, tokenizer_source, params, load_progress_callback)? - } - #[cfg(feature = "llama")] - MA::Llama => { - load_model::(path, tokenizer_source, params, load_progress_callback)? - } - #[cfg(feature = "mpt")] - MA::Mpt => { - load_model::(path, tokenizer_source, params, load_progress_callback)? - } - #[cfg(feature = "falcon")] - MA::Falcon => { - load_model::(path, tokenizer_source, params, load_progress_callback)? + let architecture = architecture.ok_or_else(|| LoadError::MissingModelArchitecture { + path: path.to_owned(), + })?; + + struct LoadVisitor<'a, F: FnMut(LoadProgress)> { + path: &'a Path, + tokenizer_source: TokenizerSource, + params: ModelParameters, + load_progress_callback: F, + } + impl<'a, F: FnMut(LoadProgress)> ModelArchitectureVisitor, LoadError>> + for LoadVisitor<'a, F> + { + fn visit(&mut self) -> Result, LoadError> { + load_model::( + self.path, + self.tokenizer_source.clone(), + self.params.clone(), + &mut self.load_progress_callback, + ) } - }; + } - Ok(model) + architecture.visit(&mut LoadVisitor { + path, + tokenizer_source, + params, + load_progress_callback, + }) } #[cfg(test)]