fix rustformers#317 - cli move architecture into subcommands
philpax committed Jun 30, 2023
1 parent 7e2f2bf commit 9a22269
Showing 8 changed files with 208 additions and 193 deletions.
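This commit flattens the per-architecture subcommands into a single set of top-level commands. Previously the architecture was selected by a leading subcommand (e.g. `llm llama infer -m model.bin`, assuming the binary is invoked as `llm`); after this change the same invocation becomes `llm infer -a llama -m model.bin`. The new `-a`/`--model-architecture` flag is optional, and the loader guesses the architecture from the model file when it is omitted.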
binaries/llm-cli/src/cli_args.rs: 69 changes (19 additions, 50 deletions)
@@ -1,6 +1,6 @@
 use std::{fmt, ops::Deref, path::PathBuf, sync::Arc};
 
-use clap::{Parser, Subcommand, ValueEnum};
+use clap::{Parser, ValueEnum};
 use color_eyre::eyre::{bail, Result, WrapErr};
 use llm::{
     ggml_format, ElementType, InferenceParameters, InferenceSessionConfig, InvalidTokenBias,
@@ -11,50 +11,6 @@ use rand::SeedableRng;
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 pub enum Args {
-    /// Use a BLOOM model
-    Bloom {
-        #[command(subcommand)]
-        args: BaseArgs,
-    },
-    /// Use a GPT-2 model
-    Gpt2 {
-        #[command(subcommand)]
-        args: BaseArgs,
-    },
-    /// Use a GPT-J model
-    #[clap(id = "gptj")]
-    GptJ {
-        #[command(subcommand)]
-        args: BaseArgs,
-    },
-    /// Use a GPT-NeoX model
-    #[clap(id = "gptneox")]
-    GptNeoX {
-        #[command(subcommand)]
-        args: BaseArgs,
-    },
-    /// Use a LLaMA model
-    Llama {
-        #[command(subcommand)]
-        args: BaseArgs,
-    },
-    /// Use a MPT model
-    #[clap(id = "mpt")]
-    Mpt {
-        #[command(subcommand)]
-        args: BaseArgs,
-    },
-    /// Use a Falcon model
-    #[clap(id = "falcon")]
-    #[cfg(feature = "falcon")]
-    Falcon {
-        #[command(subcommand)]
-        args: BaseArgs,
-    },
-}
-
-#[derive(Subcommand, Debug)]
-pub enum BaseArgs {
     #[command()]
     /// Use a model to infer the next tokens in a sequence, and exit.
     Infer(Box<Infer>),
@@ -156,7 +112,7 @@ pub struct Info {
     pub tensors: bool,
 
     /// Show all of the tokens in the tokenizer.
-    #[arg(long, short = 'v')]
+    #[arg(long, short = 'k')]
     pub tokenizer: bool,
 }

@@ -372,12 +328,22 @@ impl ModelTokenizer {
     }
 }
 
+#[derive(Parser, Debug)]
+pub struct ModelArchitecture {
+    /// The model architecture to use. Will attempt to guess if not specified.
+    #[arg(long, short = 'a')]
+    pub model_architecture: Option<llm::ModelArchitecture>,
+}
+
 #[derive(Parser, Debug)]
 pub struct ModelAndTokenizer {
     /// Where to load the model from
     #[arg(long, short = 'm')]
     pub model_path: PathBuf,
 
+    #[command(flatten)]
+    pub architecture: ModelArchitecture,
+
     #[command(flatten)]
     pub tokenizer: ModelTokenizer,
 }
@@ -415,7 +381,7 @@ pub struct ModelLoad {
     pub lora_paths: Option<Vec<PathBuf>>,
 }
 impl ModelLoad {
-    pub fn load<M: llm::KnownModel + 'static>(&self, use_gpu: bool) -> Result<Box<dyn Model>> {
+    pub fn load(&self, use_gpu: bool) -> Result<Box<dyn Model>> {
         let params = ModelParameters {
             prefer_mmap: !self.no_mmap,
             context_size: self.num_ctx_tokens,
@@ -441,7 +407,8 @@ impl ModelLoad {
             }
         };
 
-        let model = llm::load::<M>(
+        let model = llm::load_dynamic(
+            self.model_and_tokenizer.architecture.model_architecture,
             &self.model_and_tokenizer.model_path,
             tokenizer_source,
             params,
@@ -496,7 +463,6 @@ impl ModelLoad {
                 }
             },
         )
-        .map(Box::new)
         .wrap_err("Could not load model");
 
         if model.is_err() {
@@ -507,7 +473,7 @@ impl ModelLoad {
             }
         }
 
-        Ok(model?)
+        model
     }
 }
 
@@ -548,6 +514,9 @@ impl PromptFile {
 
 #[derive(Parser, Debug)]
 pub struct Quantize {
+    #[command(flatten)]
+    pub architecture: ModelArchitecture,
+
     /// The path to the model to quantize
     #[arg()]
     pub source: PathBuf,
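The `#[command(flatten)]` attribute is what lets the new `ModelArchitecture` group be declared once and reused by both `ModelAndTokenizer` and `Quantize`. A minimal standalone sketch of that clap pattern (hypothetical names, not code from this commit):

use clap::Parser;

// Hypothetical reusable flag group, mirroring ModelArchitecture above.
#[derive(Parser, Debug)]
struct ArchFlag {
    /// Architecture to use; guessed from the file when omitted.
    #[arg(long, short = 'a')]
    arch: Option<String>,
}

#[derive(Parser, Debug)]
struct Cli {
    // Flattening splices ArchFlag's -a/--arch directly onto this command.
    #[command(flatten)]
    arch: ArchFlag,

    /// Path to the model file.
    #[arg(long, short = 'm')]
    model: std::path::PathBuf,
}

fn main() {
    // e.g. `cli -a llama -m models/model.bin`
    let cli = Cli::parse();
    println!("{cli:?}");
}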
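On the loading side, the statically-typed `llm::load::<M>` call is replaced by `llm::load_dynamic`, which takes the optional architecture as its first argument and returns a boxed `dyn Model`. A minimal sketch of that call shape, following the argument order visible in this diff (the path, tokenizer source, and parameter values are placeholder assumptions):

use std::path::Path;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Passing None instead would ask the loader to guess the architecture
    // from the model file, matching the new optional -a flag.
    let model = llm::load_dynamic(
        Some(llm::ModelArchitecture::Llama),
        Path::new("models/model.bin"),   // placeholder path
        llm::TokenizerSource::Embedded,  // assumed variant: tokenizer embedded in the file
        llm::ModelParameters::default(),
        |_progress| {},                  // no-op load-progress callback
    )?;
    // `model` is a Box<dyn llm::Model>; callers no longer need the concrete type.
    drop(model);
    Ok(())
}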
(Diffs for the remaining 7 changed files are not shown here.)
