This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

feat(cli): move model architecture to base arg
philpax committed May 1, 2023
1 parent 6551d1e commit 3be6df4
Showing 5 changed files with 97 additions and 131 deletions.
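
The change in a nutshell: the model architecture is no longer a --model-architecture/-a flag on ModelLoad; it becomes a top-level subcommand (llama, bloom, gpt2), and the previous commands (infer, dump-tokens, repl, chat-experimental, quantize) move into a shared BaseArgs subcommand nested under each architecture. The hidden convert command is removed. A minimal, self-contained sketch of the pattern (names simplified; the real definitions are in the cli_args.rs diff below, and the `llm` binary name in the comments is an assumption):

use clap::{Parser, Subcommand};

#[derive(Parser, Debug)]
enum Args {
    /// Use a LLaMA model
    Llama {
        #[command(subcommand)]
        args: BaseArgs,
    },
    /// Use a GPT-2 model
    Gpt2 {
        #[command(subcommand)]
        args: BaseArgs,
    },
}

#[derive(Subcommand, Debug)]
enum BaseArgs {
    /// Use a model to infer the next tokens in a sequence
    Infer {
        /// Path to the model file
        #[arg(long, short = 'm')]
        model_path: std::path::PathBuf,
    },
}

fn main() {
    // Old invocation: llm infer -a llama -m model.bin
    // New invocation: llm llama infer -m model.bin
    let args = Args::parse_from(["llm", "llama", "infer", "-m", "model.bin"]);
    println!("{args:?}");
}

Choosing the architecture first lets everything below the subcommand be generic over the concrete model type, as the cli_args.rs and main.rs diffs below show.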
2 changes: 1 addition & 1 deletion .vscode/settings.json
@@ -1,3 +1,3 @@
{
"rust-analyzer.cargo.features": ["convert", "quantize"]
"rust-analyzer.cargo.features": ["convert"]
}
179 changes: 71 additions & 108 deletions binaries/llm-cli/src/cli_args.rs
@@ -1,9 +1,6 @@
use std::{
fmt::Debug,
path::{Path, PathBuf},
};
use std::{fmt::Debug, path::PathBuf};

use clap::{Parser, ValueEnum};
use clap::{Parser, Subcommand, ValueEnum};
use color_eyre::eyre::{Result, WrapErr};
use llm::{
ElementType, InferenceParameters, InferenceSessionParameters, LoadProgress, Model,
@@ -14,6 +11,25 @@ use rand::SeedableRng;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub enum Args {
/// Use a LLaMA model
Llama {
#[command(subcommand)]
args: BaseArgs,
},
/// Use a BLOOM model
Bloom {
#[command(subcommand)]
args: BaseArgs,
},
/// Use a GPT-2 model
Gpt2 {
#[command(subcommand)]
args: BaseArgs,
},
}

#[derive(Subcommand, Debug)]
pub enum BaseArgs {
#[command()]
/// Use a model to infer the next tokens in a sequence, and exit
Infer(Box<Infer>),
@@ -36,15 +52,6 @@ pub enum Args {
/// have an extended conversation.
ChatExperimental(Box<Repl>),

#[command(hide = true)]
/// Convert a PyTorch model to a GGML model
///
/// This is *not* fully implemented. This is a starting point for developers
/// to continue at a later stage.
///
/// For reference, see [the PR](https://github.com/rustformers/llama-rs/pull/83).
Convert(Box<Convert>),

/// Quantize a GGML model to 4-bit.
Quantize(Box<Quantize>),
}
@@ -252,10 +259,6 @@ pub struct ModelLoad {
#[arg(long, short = 'm')]
pub model_path: PathBuf,

/// The model architecture to use.
#[arg(long, short = 'a', default_value_t, value_enum)]
pub model_architecture: ModelArchitecture,

/// Sets the size of the context (in tokens). Allows feeding longer prompts.
/// Note that this affects memory.
///
Expand All @@ -274,109 +277,69 @@ pub struct ModelLoad {
#[arg(long)]
pub no_mmap: bool,
}
#[derive(Parser, Debug, ValueEnum, Clone, Copy, Default)]
pub enum ModelArchitecture {
/// Meta's LLaMA model and derivatives (Vicuna, etc).
#[default]
Llama,
/// OpenAI's GPT2 architecture and derivatives (Cerebras, etc).
Gpt2,
/// The BLOOM model. This is currently disabled as it does not work.
Bloom,
}
impl ModelLoad {
pub fn load(&self) -> Result<Box<dyn Model>> {
pub fn load<M: llm::KnownModel + 'static>(&self) -> Result<Box<dyn Model>> {
let now = std::time::Instant::now();

let prefer_mmap = !self.no_mmap;
let model = self
.load_indirect(
&self.model_path,
!self.no_mmap,
self.num_ctx_tokens,
|progress| match progress {
LoadProgress::HyperparametersLoaded => {
log::debug!("Loaded hyperparameters")
}
LoadProgress::ContextSize { bytes } => log::info!(
"ggml ctx size = {:.2} MB\n",
bytes as f64 / (1024.0 * 1024.0)
),
LoadProgress::PartLoading {
file,
let model = llm::load::<M>(
&self.model_path,
!self.no_mmap,
self.num_ctx_tokens,
|progress| match progress {
LoadProgress::HyperparametersLoaded => {
log::debug!("Loaded hyperparameters")
}
LoadProgress::ContextSize { bytes } => log::info!(
"ggml ctx size = {:.2} MB\n",
bytes as f64 / (1024.0 * 1024.0)
),
LoadProgress::PartLoading {
file,
current_part,
total_parts,
} => {
let current_part = current_part + 1;
log::info!(
"Loading model part {}/{} from '{}' (mmap preferred: {})\n",
current_part,
total_parts,
} => {
let current_part = current_part + 1;
log::info!(
"Loading model part {}/{} from '{}' (mmap preferred: {})\n",
current_part,
total_parts,
file.to_string_lossy(),
prefer_mmap
)
}
LoadProgress::PartTensorLoaded {
current_tensor,
tensor_count,
..
} => {
let current_tensor = current_tensor + 1;
if current_tensor % 8 == 0 {
log::info!("Loaded tensor {current_tensor}/{tensor_count}");
}
}
LoadProgress::PartLoaded {
file,
byte_size,
tensor_count,
} => {
log::info!("Loading of '{}' complete", file.to_string_lossy());
log::info!(
"Model size = {:.2} MB / num tensors = {}",
byte_size as f64 / 1024.0 / 1024.0,
tensor_count
);
file.to_string_lossy(),
prefer_mmap
)
}
LoadProgress::PartTensorLoaded {
current_tensor,
tensor_count,
..
} => {
let current_tensor = current_tensor + 1;
if current_tensor % 8 == 0 {
log::info!("Loaded tensor {current_tensor}/{tensor_count}");
}
},
)
.wrap_err("Could not load model")?;
}
LoadProgress::PartLoaded {
file,
byte_size,
tensor_count,
} => {
log::info!("Loading of '{}' complete", file.to_string_lossy());
log::info!(
"Model size = {:.2} MB / num tensors = {}",
byte_size as f64 / 1024.0 / 1024.0,
tensor_count
);
}
},
)
.wrap_err("Could not load model")?;

log::info!(
"Model fully loaded! Elapsed: {}ms",
now.elapsed().as_millis()
);

Ok(model)
}

fn load_indirect(
&self,
path: &Path,
prefer_mmap: bool,
n_context_tokens: usize,
load_progress_callback: impl FnMut(LoadProgress<'_>),
) -> Result<Box<dyn Model>> {
Ok(match self.model_architecture {
ModelArchitecture::Llama => Box::new(llm::load::<llm::models::Llama>(
path,
prefer_mmap,
n_context_tokens,
load_progress_callback,
)?),
ModelArchitecture::Gpt2 => Box::new(llm::load::<llm::models::Gpt2>(
path,
prefer_mmap,
n_context_tokens,
load_progress_callback,
)?),
ModelArchitecture::Bloom => Box::new(llm::load::<llm::models::Bloom>(
path,
prefer_mmap,
n_context_tokens,
load_progress_callback,
)?),
})
Ok(Box::new(model))
}
}

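With the architecture fixed by the subcommand, ModelLoad::load is now generic over M: llm::KnownModel and calls llm::load::<M> directly, so the runtime match over a ModelArchitecture enum (the old load_indirect) disappears. A rough sketch of the resulting shape, with the progress callback stubbed out (the real callback is in the diff above):

use color_eyre::eyre::{Result, WrapErr};

fn load_model<M: llm::KnownModel + 'static>(
    path: &std::path::Path,
    prefer_mmap: bool,
    n_context_tokens: usize,
) -> Result<Box<dyn llm::Model>> {
    // M picks the concrete model at compile time; the result is then erased
    // behind the object-safe Model trait so sessions, snapshots and inference
    // code stay architecture-agnostic.
    let model = llm::load::<M>(path, prefer_mmap, n_context_tokens, |_progress| {})
        .wrap_err("Could not load model")?;
    Ok(Box::new(model))
}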
41 changes: 22 additions & 19 deletions binaries/llm-cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{convert::Infallible, io::Write};

use clap::Parser;
use cli_args::Args;
use cli_args::{Args, BaseArgs};
use color_eyre::eyre::{Context, Result};
use llm::InferenceError;
use rustyline::error::ReadlineError;
@@ -17,24 +17,27 @@ fn main() -> Result<()> {
color_eyre::install()?;

let cli_args = Args::parse();
match cli_args {
Args::Infer(args) => infer(&args)?,
Args::DumpTokens(args) => dump_tokens(&args)?,
Args::Repl(args) => interactive(&args, false)?,
Args::ChatExperimental(args) => interactive(&args, true)?,
Args::Convert(args) => {
llm::models::llama::convert::convert_pth_to_ggml(&args.directory, args.file_type.into())
}
Args::Quantize(args) => quantize(&args)?,
match &cli_args {
Args::Llama { args } => handle_args::<llm::models::Llama>(args),
Args::Bloom { args } => handle_args::<llm::models::Bloom>(args),
Args::Gpt2 { args } => handle_args::<llm::models::Gpt2>(args),
}
}

Ok(())
fn handle_args<M: llm::KnownModel + 'static>(args: &cli_args::BaseArgs) -> Result<()> {
match args {
BaseArgs::Infer(args) => infer::<M>(args),
BaseArgs::DumpTokens(args) => dump_tokens::<M>(args),
BaseArgs::Repl(args) => interactive::<M>(args, false),
BaseArgs::ChatExperimental(args) => interactive::<M>(args, true),
BaseArgs::Quantize(args) => quantize::<M>(args),
}
}

fn infer(args: &cli_args::Infer) -> Result<()> {
fn infer<M: llm::KnownModel + 'static>(args: &cli_args::Infer) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let inference_session_params = args.generate.inference_session_parameters();
let model = args.model_load.load()?;
let model = args.model_load.load::<M>()?;
let (mut session, session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
args.persist_session.as_deref(),
Expand Down Expand Up @@ -83,9 +86,9 @@ fn infer(args: &cli_args::Infer) -> Result<()> {
Ok(())
}

fn dump_tokens(args: &cli_args::DumpTokens) -> Result<()> {
fn dump_tokens<M: llm::KnownModel + 'static>(args: &cli_args::DumpTokens) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let model = args.model_load.load()?;
let model = args.model_load.load::<M>()?;
let toks = match model.vocabulary().tokenize(&prompt, false) {
Ok(toks) => toks,
Err(e) => {
@@ -112,15 +115,15 @@ fn dump_tokens(args: &cli_args::DumpTokens) -> Result<()> {
Ok(())
}

fn interactive(
fn interactive<M: llm::KnownModel + 'static>(
args: &cli_args::Repl,
// If set to false, the session will be cloned after each inference
// to ensure that previous state is not carried over.
chat_mode: bool,
) -> Result<()> {
let prompt_file = args.prompt_file.contents();
let inference_session_params = args.generate.inference_session_parameters();
let model = args.model_load.load()?;
let model = args.model_load.load::<M>()?;
let (mut session, session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
None,
@@ -194,10 +197,10 @@ fn interactive(
Ok(())
}

fn quantize(args: &cli_args::Quantize) -> Result<()> {
fn quantize<M: llm::KnownModel + 'static>(args: &cli_args::Quantize) -> Result<()> {
use llm::quantize::QuantizeProgress::*;

llm::quantize::quantize::<llm::models::Llama>(
llm::quantize::quantize::<M>(
&args.source,
&args.destination,
args.target.into(),
2 changes: 1 addition & 1 deletion crates/llm-base/src/model.rs
@@ -100,7 +100,7 @@ impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M {
/// Implemented by model hyperparameters for loading and saving to a GGML model read/writer.
pub trait Hyperparameters: Sized + Default {
/// The error type returned during a failure of [Self::write].
type WriteError: Error + 'static;
type WriteError: Error + Send + Sync + 'static;

/// Read the parameters from a reader.
fn read(reader: &mut dyn BufRead) -> Result<Self, LoadError>;
4 changes: 2 additions & 2 deletions crates/llm-base/src/quantize.rs
@@ -65,7 +65,7 @@ pub enum QuantizeProgress<'a> {

#[derive(Error, Debug)]
/// Errors encountered during the quantization process.
pub enum QuantizeError<E: std::error::Error> {
pub enum QuantizeError<E: std::error::Error + Send + Sync> {
#[error("could not load model")]
/// There was an error while attempting to load the model.
Load(#[from] LoadError),
@@ -112,7 +112,7 @@ pub enum QuantizeError<E: std::error::Error> {
#[error("an error was encountered while writing model-specific data")]
WriteError(#[source] E),
}
impl<E: std::error::Error + 'static> QuantizeError<E> {
impl<E: std::error::Error + Send + Sync + 'static> QuantizeError<E> {
pub(crate) fn from_format_error(value: SaveError<QuantizeError<E>>, path: PathBuf) -> Self {
match value {
SaveError::Io(io) => QuantizeError::Io(io),
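The new Send + Sync bounds on Hyperparameters::WriteError and QuantizeError appear to exist so the model-specific errors can flow into color_eyre's Report (used throughout main.rs via ? and wrap_err), which only accepts errors that are Error + Send + Sync + 'static. A minimal illustration of that constraint, not taken from the commit:

fn into_report<E>(err: E) -> color_eyre::eyre::Report
where
    E: std::error::Error + Send + Sync + 'static,
{
    // Report's From impl requires exactly these bounds; dropping Send or Sync
    // from the generic error type makes this conversion stop compiling.
    color_eyre::eyre::Report::from(err)
}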
