From 9a222690cd0a9e3322bb6a926d54e90d08fb2c0f Mon Sep 17 00:00:00 2001
From: Philpax
Date: Fri, 30 Jun 2023 03:17:14 +0200
Subject: [PATCH] fix #317 - CLI: replace per-architecture subcommands with a model architecture flag
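
Previously the CLI required the model architecture as the top-level subcommand
(e.g. `llm llama infer ...`), with the actual actions nested under it as
`BaseArgs`. This patch flattens the actions (`infer`, `perplexity`, `info`,
`prompt-tokens`, `repl`, `chat`, `quantize`) into top-level subcommands and
turns the architecture into an optional `-a`/`--model-architecture` flag,
dispatching through the new `ModelArchitectureVisitor` trait and an
`Option<ModelArchitecture>` parameter on `llm::load_dynamic`.

As a rough illustration of the new library entry point (the model path,
tokenizer source, and progress callback below are assumptions based on the
crate's examples, not part of this patch):

    fn load_example() -> Result<Box<dyn llm::Model>, llm::LoadError> {
        // Hypothetical GGML model path, for illustration only.
        let model_path = std::path::PathBuf::from("models/example-q4_0.bin");

        llm::load_dynamic(
            // The architecture is now an Option: pass Some(..) to select it
            // explicitly, or None -- which, as of this patch, fails with
            // LoadError::MissingModelArchitecture.
            Some(llm::ModelArchitecture::Llama),
            &model_path,
            llm::TokenizerSource::Embedded,
            llm::ModelParameters::default(),
            llm::load_progress_callback_stdout,
        )
    }

On the command line this corresponds to something like
`llm infer -a llama -m models/example-q4_0.bin` instead of the old
`llm llama infer -m ...` form.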
---
binaries/llm-cli/src/cli_args.rs | 69 +++------
binaries/llm-cli/src/main.rs | 220 +++++++++++++++--------------
crates/llm-base/src/loader.rs | 12 +-
crates/llm-base/src/model/mod.rs | 3 +-
crates/llm/examples/embeddings.rs | 2 +-
crates/llm/examples/inference.rs | 2 +-
crates/llm/examples/vicuna-chat.rs | 2 +-
crates/llm/src/lib.rs | 91 +++++++-----
8 files changed, 208 insertions(+), 193 deletions(-)
diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs
index ce7db33f..0da6e3c5 100644
--- a/binaries/llm-cli/src/cli_args.rs
+++ b/binaries/llm-cli/src/cli_args.rs
@@ -1,6 +1,6 @@
use std::{fmt, ops::Deref, path::PathBuf, sync::Arc};
-use clap::{Parser, Subcommand, ValueEnum};
+use clap::{Parser, ValueEnum};
use color_eyre::eyre::{bail, Result, WrapErr};
use llm::{
ggml_format, ElementType, InferenceParameters, InferenceSessionConfig, InvalidTokenBias,
@@ -11,50 +11,6 @@ use rand::SeedableRng;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub enum Args {
- /// Use a BLOOM model
- Bloom {
- #[command(subcommand)]
- args: BaseArgs,
- },
- /// Use a GPT-2 model
- Gpt2 {
- #[command(subcommand)]
- args: BaseArgs,
- },
- /// Use a GPT-J model
- #[clap(id = "gptj")]
- GptJ {
- #[command(subcommand)]
- args: BaseArgs,
- },
- /// Use a GPT-NeoX model
- #[clap(id = "gptneox")]
- GptNeoX {
- #[command(subcommand)]
- args: BaseArgs,
- },
- /// Use a LLaMA model
- Llama {
- #[command(subcommand)]
- args: BaseArgs,
- },
- /// Use a MPT model
- #[clap(id = "mpt")]
- Mpt {
- #[command(subcommand)]
- args: BaseArgs,
- },
- /// Use a Falcon model
- #[clap(id = "falcon")]
- #[cfg(feature = "falcon")]
- Falcon {
- #[command(subcommand)]
- args: BaseArgs,
- },
-}
-
-#[derive(Subcommand, Debug)]
-pub enum BaseArgs {
#[command()]
/// Use a model to infer the next tokens in a sequence, and exit.
Infer(Box<Infer>),
@@ -156,7 +112,7 @@ pub struct Info {
pub tensors: bool,
/// Show all of the tokens in the tokenizer.
- #[arg(long, short = 'v')]
+ #[arg(long, short = 'k')]
pub tokenizer: bool,
}
@@ -372,12 +328,22 @@ impl ModelTokenizer {
}
}
+#[derive(Parser, Debug)]
+pub struct ModelArchitecture {
+ /// The model architecture to use. Will attempt to guess if not specified.
+ #[arg(long, short = 'a')]
+ pub model_architecture: Option<llm::ModelArchitecture>,
+}
+
#[derive(Parser, Debug)]
pub struct ModelAndTokenizer {
/// Where to load the model from
#[arg(long, short = 'm')]
pub model_path: PathBuf,
+ #[command(flatten)]
+ pub architecture: ModelArchitecture,
+
#[command(flatten)]
pub tokenizer: ModelTokenizer,
}
@@ -415,7 +381,7 @@ pub struct ModelLoad {
pub lora_paths: Option<Vec<PathBuf>>,
}
impl ModelLoad {
- pub fn load<M: llm::KnownModel + 'static>(&self, use_gpu: bool) -> Result<Box<M>> {
+ pub fn load(&self, use_gpu: bool) -> Result<Box<dyn Model>> {
let params = ModelParameters {
prefer_mmap: !self.no_mmap,
context_size: self.num_ctx_tokens,
@@ -441,7 +407,8 @@ impl ModelLoad {
}
};
- let model = llm::load::<M>(
+ let model = llm::load_dynamic(
+ self.model_and_tokenizer.architecture.model_architecture,
&self.model_and_tokenizer.model_path,
tokenizer_source,
params,
@@ -496,7 +463,6 @@ impl ModelLoad {
}
},
)
- .map(Box::new)
.wrap_err("Could not load model");
if model.is_err() {
@@ -507,7 +473,7 @@ impl ModelLoad {
}
}
- Ok(model?)
+ model
}
}
@@ -548,6 +514,9 @@ impl PromptFile {
#[derive(Parser, Debug)]
pub struct Quantize {
+ #[command(flatten)]
+ pub architecture: ModelArchitecture,
+
/// The path to the model to quantize
#[arg()]
pub source: PathBuf,
diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs
index 45e4c127..443f6733 100644
--- a/binaries/llm-cli/src/main.rs
+++ b/binaries/llm-cli/src/main.rs
@@ -5,8 +5,8 @@ use std::{
};
use clap::Parser;
-use cli_args::{Args, BaseArgs};
-use color_eyre::eyre::{Context, Result};
+use cli_args::Args;
+use color_eyre::eyre::{Context, ContextCompat, Result};
use llm::{InferenceError, InferenceFeedback, InferenceResponse};
use rustyline::{
error::ReadlineError,
@@ -25,35 +25,22 @@ fn main() -> Result<()> {
.init();
color_eyre::install()?;
- let cli_args = Args::parse();
- match &cli_args {
- Args::Llama { args } => handle_args::<llm::models::Llama>(args),
- Args::Bloom { args } => handle_args::<llm::models::Bloom>(args),
- Args::Gpt2 { args } => handle_args::<llm::models::Gpt2>(args),
- Args::GptJ { args } => handle_args::<llm::models::GptJ>(args),
- Args::GptNeoX { args } => handle_args::<llm::models::GptNeoX>(args),
- Args::Mpt { args } => handle_args::<llm::models::Mpt>(args),
- #[cfg(feature = "falcon")]
- Args::Falcon { args } => handle_args::<llm::models::Falcon>(args),
- }
-}
-
-fn handle_args<M: llm::KnownModel + 'static>(args: &cli_args::BaseArgs) -> Result<()> {
+ let args = Args::parse();
match args {
- BaseArgs::Infer(args) => infer::<M>(args),
- BaseArgs::Perplexity(args) => perplexity::<M>(args),
- BaseArgs::Info(args) => info::<M>(args),
- BaseArgs::PromptTokens(args) => prompt_tokens::<M>(args),
- BaseArgs::Repl(args) => interactive::<M>(args, false),
- BaseArgs::Chat(args) => interactive::<M>(args, true),
- BaseArgs::Quantize(args) => quantize::<M>(args),
+ Args::Infer(args) => infer(&args),
+ Args::Perplexity(args) => perplexity(&args),
+ Args::Info(args) => info(&args),
+ Args::PromptTokens(args) => prompt_tokens(&args),
+ Args::Repl(args) => interactive(&args, false),
+ Args::Chat(args) => interactive(&args, true),
+ Args::Quantize(args) => quantize(&args),
}
}
-fn infer<M: llm::KnownModel + 'static>(args: &cli_args::Infer) -> Result<()> {
+fn infer(args: &cli_args::Infer) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let inference_session_config = args.generate.inference_session_config();
- let model = args.model_load.load::<M>(args.generate.use_gpu)?;
+ let model = args.model_load.load(args.generate.use_gpu)?;
let (mut session, session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
@@ -118,10 +105,10 @@ fn infer(args: &cli_args::Infer) -> Result<()> {
Ok(())
}
-fn perplexity<M: llm::KnownModel + 'static>(args: &cli_args::Perplexity) -> Result<()> {
+fn perplexity(args: &cli_args::Perplexity) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let inference_session_config = args.generate.inference_session_config();
- let model = args.model_load.load::<M>(args.generate.use_gpu)?;
+ let model = args.model_load.load(args.generate.use_gpu)?;
let (mut session, _) = snapshot::read_or_create_session(
model.as_ref(),
None,
@@ -142,48 +129,62 @@ fn perplexity(args: &cli_args::Perplexity) -> Resu
Ok(())
}
-fn info<M: llm::KnownModel + 'static>(args: &cli_args::Info) -> Result<()> {
- let model_path = &args.model_and_tokenizer.model_path;
- let tokenizer = args.model_and_tokenizer.to_source()?.retrieve(model_path)?;
+fn info(args: &cli_args::Info) -> Result<()> {
+ struct InfoVisitor<'a>(&'a cli_args::Info);
+ impl llm::ModelArchitectureVisitor<Result<()>> for InfoVisitor<'_> {
+ fn visit<M: llm::KnownModel + 'static>(&mut self) -> Result<()> {
+ let args = self.0;
- let file = File::open(model_path)?;
- let mut reader = BufReader::new(&file);
- let mut loader: llm::Loader<M::Hyperparameters, _> = llm::Loader::new(tokenizer, |_| {
- // We purposely do not print progress here, as we are only interested in the metadata
- });
+ let model_path = &args.model_and_tokenizer.model_path;
+ let tokenizer = args.model_and_tokenizer.to_source()?.retrieve(model_path)?;
- llm::ggml_format::load(&mut reader, &mut loader)?;
+ let file = File::open(model_path)?;
+ let mut reader = BufReader::new(&file);
+ let mut loader: llm::Loader<M::Hyperparameters, _> =
+ llm::Loader::new(tokenizer, |_| {
+ // We purposely do not print progress here, as we are only interested in the metadata
+ });
- log::info!("Container type: {:?}", loader.container_type);
- log::info!("Hyperparameters: {:?}", loader.hyperparameters);
- log::info!("Tokenizer vocabulary size: {}", loader.tokenizer.len());
+ llm::ggml_format::load(&mut reader, &mut loader)?;
- if args.tokenizer {
- log::info!("Tokens:");
- for i in 0..loader.tokenizer.len() {
- log::info!("- {}: {}", i, utf8_or_array(&loader.tokenizer.token(i)));
- }
- }
+ log::info!("Container type: {:?}", loader.container_type);
+ log::info!("Hyperparameters: {:?}", loader.hyperparameters);
+ log::info!("Tokenizer vocabulary size: {}", loader.tokenizer.len());
- if args.tensors {
- log::info!("Tensors:");
- for (name, tensor) in &loader.tensors {
- log::info!("- {} ({:?} {:?})", name, tensor.element_type, tensor.dims());
- }
- }
+ if args.tokenizer {
+ log::info!("Tokens:");
+ for i in 0..loader.tokenizer.len() {
+ log::info!("- {}: {}", i, utf8_or_array(&loader.tokenizer.token(i)));
+ }
+ }
- fn utf8_or_array(token: &[u8]) -> String {
- std::str::from_utf8(token)
- .map(|s| s.to_owned())
- .unwrap_or(format!("{:?}", token))
+ if args.tensors {
+ log::info!("Tensors:");
+ for (name, tensor) in &loader.tensors {
+ log::info!("- {} ({:?} {:?})", name, tensor.element_type, tensor.dims());
+ }
+ }
+
+ fn utf8_or_array(token: &[u8]) -> String {
+ std::str::from_utf8(token)
+ .map(|s| s.to_owned())
+ .unwrap_or(format!("{:?}", token))
+ }
+
+ Ok(())
+ }
}
- Ok(())
+ args.model_and_tokenizer
+ .architecture
+ .model_architecture
+ .wrap_err("a model architecture is required at present")?
+ .visit(&mut InfoVisitor(args))
}
-fn prompt_tokens<M: llm::KnownModel + 'static>(args: &cli_args::PromptTokens) -> Result<()> {
+fn prompt_tokens(args: &cli_args::PromptTokens) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
- let model = args.model_load.load::<M>(false)?;
+ let model = args.model_load.load(false)?;
let toks = match model.tokenizer().tokenize(&prompt, false) {
Ok(toks) => toks,
Err(e) => {
@@ -222,7 +223,7 @@ fn force_newline_event_seq() -> KeyEvent {
KeyEvent(KeyCode::Enter, Modifiers::SHIFT)
}
-fn interactive<M: llm::KnownModel + 'static>(
+fn interactive(
args: &cli_args::Repl,
// If set to false, the session will be cloned after each inference
// to ensure that previous state is not carried over.
@@ -230,7 +231,7 @@ fn interactive(
) -> Result<()> {
let prompt_file = args.prompt_file.contents();
let inference_session_config = args.generate.inference_session_config();
- let model = args.model_load.load::<M>(args.generate.use_gpu)?;
+ let model = args.model_load.load(args.generate.use_gpu)?;
let (mut session, mut session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
None,
@@ -318,51 +319,64 @@ fn interactive(
Ok(())
}
-fn quantize<M: llm::KnownModel + 'static>(args: &cli_args::Quantize) -> Result<()> {
+fn quantize(args: &cli_args::Quantize) -> Result<()> {
use llm::QuantizeProgress;
- let mut source = BufReader::new(std::fs::File::open(&args.source)?);
- let mut destination = BufWriter::new(std::fs::File::create(&args.destination)?);
- let tokenizer = args.tokenizer.to_source()?.retrieve(&args.source)?;
-
- llm::quantize::<M, _, _>(
- &mut source,
- &mut destination,
- tokenizer,
- args.container_type.into(),
- args.target.into(),
- |progress| match progress {
- QuantizeProgress::HyperparametersLoaded => log::info!("Loaded hyperparameters"),
- QuantizeProgress::TensorLoading {
- name,
- dims,
- element_type,
- n_elements,
- } => log::info!(
- "Loading tensor `{name}` ({n_elements} ({dims:?}) {element_type} elements)"
- ),
- QuantizeProgress::TensorQuantizing { name } => log::info!("Quantizing tensor `{name}`"),
- QuantizeProgress::TensorQuantized {
- name,
- original_size,
- reduced_size,
- history,
- } => log::info!(
- "Quantized tensor `{name}` from {original_size} to {reduced_size} bytes ({history:?})"
- ),
- QuantizeProgress::TensorSkipped { name, size } => {
- log::info!("Skipped tensor `{name}` ({size} bytes)")
- }
- QuantizeProgress::Finished {
- original_size,
- reduced_size,
- history,
- } => log::info!(
- "Finished quantization from {original_size} to {reduced_size} bytes ({history:?})"
- ),
- },
- )
- .wrap_err("failed to quantize model")
+ struct QuantizeVisitor<'a>(&'a cli_args::Quantize);
+ impl llm::ModelArchitectureVisitor<Result<()>> for QuantizeVisitor<'_> {
+ fn visit<M: llm::KnownModel + 'static>(&mut self) -> Result<()> {
+ let args = self.0;
+
+ let mut source: BufReader<File> = BufReader::new(std::fs::File::open(&args.source)?);
+ let mut destination: BufWriter<File> =
+ BufWriter::new(std::fs::File::create(&args.destination)?);
+ let tokenizer: llm::Tokenizer = args.tokenizer.to_source()?.retrieve(&args.source)?;
+
+ llm::quantize::<M, _, _>(
+ &mut source,
+ &mut destination,
+ tokenizer,
+ args.container_type.into(),
+ args.target.into(),
+ |progress| match progress {
+ QuantizeProgress::HyperparametersLoaded => log::info!("Loaded hyperparameters"),
+ QuantizeProgress::TensorLoading {
+ name,
+ dims,
+ element_type,
+ n_elements,
+ } => log::info!(
+ "Loading tensor `{name}` ({n_elements} ({dims:?}) {element_type} elements)"
+ ),
+ QuantizeProgress::TensorQuantizing { name } => log::info!("Quantizing tensor `{name}`"),
+ QuantizeProgress::TensorQuantized {
+ name,
+ original_size,
+ reduced_size,
+ history,
+ } => log::info!(
+ "Quantized tensor `{name}` from {original_size} to {reduced_size} bytes ({history:?})"
+ ),
+ QuantizeProgress::TensorSkipped { name, size } => {
+ log::info!("Skipped tensor `{name}` ({size} bytes)")
+ }
+ QuantizeProgress::Finished {
+ original_size,
+ reduced_size,
+ history,
+ } => log::info!(
+ "Finished quantization from {original_size} to {reduced_size} bytes ({history:?})"
+ ),
+ },
+ )
+ .wrap_err("failed to quantize model")
+ }
+ }
+
+ args.architecture
+ .model_architecture
+ .wrap_err("the architecture must be known for quantization")?
+ .visit(&mut QuantizeVisitor(args))
}
fn load_prompt_file_with_prompt(
diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs
index 13b04516..1725535b 100644
--- a/crates/llm-base/src/loader.rs
+++ b/crates/llm-base/src/loader.rs
@@ -327,13 +327,21 @@ pub enum LoadError {
},
/// The tokenizer could not be loaded.
#[error("could not load tokenizer {path:?}: {error}")]
- TokenizerLoadError {
+ TokenizerLoadFail {
/// The invalid tokenizer path
path: PathBuf,
/// The error that occurred.
error: Box<dyn Error + Send + Sync>,
},
+ /// There is insufficient information to guess the model architecture from the provided file.
+ ///
+ /// A model architecture must be provided to load the model.
+ #[error("could not guess model architecture from {path:?}")]
+ MissingModelArchitecture {
+ /// The path that failed.
+ path: PathBuf,
+ },
}
impl From<util::FindAllModelFilesError> for LoadError {
fn from(value: util::FindAllModelFilesError) -> Self {
@@ -345,7 +353,7 @@ impl From for LoadError {
}
impl From<TokenizerLoadError> for LoadError {
fn from(value: TokenizerLoadError) -> Self {
- LoadError::TokenizerLoadError {
+ LoadError::TokenizerLoadFail {
path: value.path,
error: value.error,
}
diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs
index bee50f37..45eb8650 100644
--- a/crates/llm-base/src/model/mod.rs
+++ b/crates/llm-base/src/model/mod.rs
@@ -177,7 +177,8 @@ pub enum HyperparametersWriteError {
InvalidIntegerConversion(#[from] std::num::TryFromIntError),
}
-/// Parameters for tuning model instances
+/// Parameters for model-wide behaviour.
+#[derive(Debug, Clone)]
pub struct ModelParameters {
/// For [GGML formats](ggml::ContainerType) that support it, [mmap](https://en.wikipedia.org/wiki/Mmap)
/// is the default. Although mmap typically improves performance, setting this value to `false` may
diff --git a/crates/llm/examples/embeddings.rs b/crates/llm/examples/embeddings.rs
index 74207a1d..0a6a999a 100644
--- a/crates/llm/examples/embeddings.rs
+++ b/crates/llm/examples/embeddings.rs
@@ -51,7 +51,7 @@ fn main() {
// Load model
let model_params = llm::ModelParameters::default();
let model = llm::load_dynamic(
- model_architecture,
+ Some(model_architecture),
&model_path,
tokenizer_source,
model_params,
diff --git a/crates/llm/examples/inference.rs b/crates/llm/examples/inference.rs
index aa740b02..51e7369a 100644
--- a/crates/llm/examples/inference.rs
+++ b/crates/llm/examples/inference.rs
@@ -39,7 +39,7 @@ fn main() {
let now = std::time::Instant::now();
let model = llm::load_dynamic(
- model_architecture,
+ Some(model_architecture),
&model_path,
tokenizer_source,
Default::default(),
diff --git a/crates/llm/examples/vicuna-chat.rs b/crates/llm/examples/vicuna-chat.rs
index e08f0be3..7cdeb1d1 100644
--- a/crates/llm/examples/vicuna-chat.rs
+++ b/crates/llm/examples/vicuna-chat.rs
@@ -31,7 +31,7 @@ fn main() {
let model_architecture = args.model_architecture;
let model_path = args.model_path;
let model = llm::load_dynamic(
- model_architecture,
+ Some(model_architecture),
&model_path,
tokenizer_source,
Default::default(),
diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs
index 30ea6c56..c165deb5 100644
--- a/crates/llm/src/lib.rs
+++ b/crates/llm/src/lib.rs
@@ -153,6 +153,33 @@ impl ModelArchitecture {
];
}
+/// Used to dispatch some code based on the model architecture.
+pub trait ModelArchitectureVisitor<R> {
+ /// Visit a model architecture.
+ fn visit<M: KnownModel + 'static>(&mut self) -> R;
+}
+impl ModelArchitecture {
+ /// Use a visitor to dispatch some code based on the model architecture.
+ pub fn visit<R>(&self, visitor: &mut impl ModelArchitectureVisitor<R>) -> R {
+ match self {
+ #[cfg(feature = "bloom")]
+ Self::Bloom => visitor.visit::<models::Bloom>(),
+ #[cfg(feature = "gpt2")]
+ Self::Gpt2 => visitor.visit::<models::Gpt2>(),
+ #[cfg(feature = "gptj")]
+ Self::GptJ => visitor.visit::<models::GptJ>(),
+ #[cfg(feature = "gptneox")]
+ Self::GptNeoX => visitor.visit::<models::GptNeoX>(),
+ #[cfg(feature = "llama")]
+ Self::Llama => visitor.visit::<models::Llama>(),
+ #[cfg(feature = "mpt")]
+ Self::Mpt => visitor.visit::<models::Mpt>(),
+ #[cfg(feature = "falcon")]
+ Self::Falcon => visitor.visit::<models::Falcon>(),
+ }
+ }
+}
+
/// An unsupported model architecture was specified.
pub struct UnsupportedModelArchitecture(String);
impl Display for UnsupportedModelArchitecture {
@@ -227,18 +254,17 @@ impl Display for ModelArchitecture {
}
/// A helper function that loads the specified model from disk using an architecture
-/// specified at runtime.
+/// specified at runtime. If no architecture is specified, it will try to infer it
+/// from the model's metadata.
///
/// A wrapper around [load] that dispatches to the correct model.
pub fn load_dynamic(
- architecture: ModelArchitecture,
+ architecture: Option,
path: &Path,
tokenizer_source: TokenizerSource,
params: ModelParameters,
load_progress_callback: impl FnMut(LoadProgress),
) -> Result<Box<dyn Model>, LoadError> {
- use ModelArchitecture as MA;
-
fn load_model<M: KnownModel + 'static>(
path: &Path,
tokenizer_source: TokenizerSource,
@@ -253,38 +279,35 @@ pub fn load_dynamic(
)?))
}
- let model: Box<dyn Model> = match architecture {
- #[cfg(feature = "bloom")]
- MA::Bloom => {
- load_model::<models::Bloom>(path, tokenizer_source, params, load_progress_callback)?
- }
- #[cfg(feature = "gpt2")]
- MA::Gpt2 => {
- load_model::<models::Gpt2>(path, tokenizer_source, params, load_progress_callback)?
- }
- #[cfg(feature = "gptj")]
- MA::GptJ => {
- load_model::<models::GptJ>(path, tokenizer_source, params, load_progress_callback)?
- }
- #[cfg(feature = "gptneox")]
- MA::GptNeoX => {
- load_model::<models::GptNeoX>(path, tokenizer_source, params, load_progress_callback)?
- }
- #[cfg(feature = "llama")]
- MA::Llama => {
- load_model::<models::Llama>(path, tokenizer_source, params, load_progress_callback)?
- }
- #[cfg(feature = "mpt")]
- MA::Mpt => {
- load_model::<models::Mpt>(path, tokenizer_source, params, load_progress_callback)?
- }
- #[cfg(feature = "falcon")]
- MA::Falcon => {
- load_model::<models::Falcon>(path, tokenizer_source, params, load_progress_callback)?
+ let architecture = architecture.ok_or_else(|| LoadError::MissingModelArchitecture {
+ path: path.to_owned(),
+ })?;
+
+ struct LoadVisitor<'a, F: FnMut(LoadProgress)> {
+ path: &'a Path,
+ tokenizer_source: TokenizerSource,
+ params: ModelParameters,
+ load_progress_callback: F,
+ }
+ impl<'a, F: FnMut(LoadProgress)> ModelArchitectureVisitor<Result<Box<dyn Model>, LoadError>>
+ for LoadVisitor<'a, F>
+ {
+ fn visit<M: KnownModel + 'static>(&mut self) -> Result<Box<dyn Model>, LoadError> {
+ load_model::<M>(
+ self.path,
+ self.tokenizer_source.clone(),
+ self.params.clone(),
+ &mut self.load_progress_callback,
+ )
}
- };
+ }
- Ok(model)
+ architecture.visit(&mut LoadVisitor {
+ path,
+ tokenizer_source,
+ params,
+ load_progress_callback,
+ })
}
#[cfg(test)]