Commit
Merge branch 'main' into dfo/model/mpt
philpax committed May 14, 2023
2 parents 0ae4c8d + f523cbf · commit dbcd407
Showing 28 changed files with 861 additions and 271 deletions.
14 changes: 14 additions & 0 deletions .vscode/launch.json
@@ -92,6 +92,20 @@
"args": ["gptneox", "${env:HOME}/.ggml-models/stablelm-base-alpha-3b.bin"],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug RedPajama Inference",
"cargo": {
"args": ["build", "--example=inference", "--package=llm"],
"filter": {
"name": "inference",
"kind": "example"
}
},
"args": ["redpajama", "${env:HOME}/.ggml-models/redpajama-incite-7b.bin"],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -22,6 +22,7 @@ log = "0.4"
rand = "0.8.5"
rustyline = { version = "11.0.0", features = ["derive"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0" }
spinoff = { version = "0.7.0", default-features = false, features = ["dots2"] }
thiserror = "1.0"

36 changes: 21 additions & 15 deletions README.md
@@ -15,23 +15,25 @@ The primary crate is the `llm` crate, which wraps `llm-base` and supported model
crates.

On top of `llm`, there is a CLI application, `llm-cli`, which provides a
convenient interface for running inference on supported models. Inferencing
can be done as a one-off, or interactively, through REPL or chat modes. It
can also print information about, or quantize, a GGML model. It can be
downloaded from [the latest GitHub release](https://github.com/rustformers/llm/releases)
or by installing it from `crates.io`.
convenient interface for running inference on supported models. Inferencing can
be done as a one-off, or interactively, through REPL or chat modes. It can also
print information about, or quantize, a GGML model. It can be downloaded from
[the latest GitHub release](https://github.com/rustformers/llm/releases) or by
installing it from `crates.io`.

`llm` is powered by the [`ggml`](https://github.com/ggerganov/ggml) tensor library,
and aims to bring the robustness and ease of use of Rust to the world of large
language models. At present, inference is only on the CPU, but we hope to
support GPU inference in the future through alternate backends.
`llm` is powered by the [`ggml`](https://github.com/ggerganov/ggml) tensor
library, and aims to bring the robustness and ease of use of Rust to the world
of large language models. At present, inference is only on the CPU, but we hope
to support GPU inference in the future through alternate backends.

Currently, the following models are supported:

- [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)
- [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj)
- [LLaMA](https://huggingface.co/docs/transformers/model_doc/llama): LLaMA, Alpaca, Vicuna, Koala, GPT4All v1, GPT4-X, Wizard
- [GPT-NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox): GPT-NeoX, StableLM, Dolly v2 (partial, not the same tensor names?)
- [LLaMA](https://huggingface.co/docs/transformers/model_doc/llama): LLaMA,
Alpaca, Vicuna, Koala, GPT4All v1, GPT4-X, Wizard
- [GPT-NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox):
GPT-NeoX, StableLM, RedPajama, Dolly v2
- [BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom): BLOOMZ
- [MPT](https://www.mosaicml.com/blog/mpt-7b)

@@ -91,7 +93,8 @@ This is useful for development.

### Getting models

GGML files are easy to acquire. For a list of models that have been tested, see the [known-good models](./known-good-models.md).
GGML files are easy to acquire. For a list of models that have been tested, see
the [known-good models](./known-good-models.md).

Certain older GGML formats are not supported by this project, but the goal is to
maintain feature parity with the upstream GGML project. For problems relating to
@@ -134,7 +137,8 @@ python3 scripts/convert-pth-to-ggml.py /path/to/your/models/7B/ 1
cargo run --release llama quantize /path/to/your/models/7B/ggml-model-f16.bin /path/to/your/models/7B/ggml-model-q4_0.bin q4_0
```

In future, we hope to provide [a more streamlined way of converting models](https://github.com/rustformers/llm/issues/21).
In future, we hope to provide
[a more streamlined way of converting models](https://github.com/rustformers/llm/issues/21).

> **Note**
>
@@ -230,8 +234,10 @@ use and deploy as any other Rust crate.

#### Applications

- [llmcord](https://github.com/rustformers/llmcord): Discord bot for generating messages using `llm`.
- [llmcord](https://github.com/rustformers/llmcord): Discord bot for generating
messages using `llm`.

#### Libraries

- [llm-chain](https://github.com/sobelio/llm-chain): Work in progress, see [this PR](https://github.com/sobelio/llm-chain/pull/116).
- [llm-chain](https://github.com/sobelio/llm-chain): Build chains in large
language models for text summarization and completion of more complex tasks
122 changes: 78 additions & 44 deletions binaries/llm-cli/src/cli_args.rs
@@ -37,6 +37,21 @@ pub enum Args {
NeoX {
#[command(subcommand)]
args: BaseArgs,

#[arg(long)]
/// By default, the GPT-NeoX architecture uses a parallel residual.
///
/// This flag disables that, as some models out there are trained without it,
/// and the model format does not store this information.
no_parallel_residual: bool,
},
/// Use a model from the RedPajama GPT-NeoX family
///
/// (GPT-NeoX with `use_parallel_residual` set to false)
#[clap(id = "redpajama")]
RedPajama {
#[command(subcommand)]
args: BaseArgs,
},
/// Use an MPT model
#[clap(id = "mpt")]
@@ -316,12 +331,20 @@ pub struct ModelLoad {
/// Don't use mmap to load the model.
#[arg(long)]
pub no_mmap: bool,

/// LoRA adapter to use for the model
#[arg(long)]
pub lora_path: Option<PathBuf>,
}
impl ModelLoad {
pub fn load<M: llm::KnownModel + 'static>(&self) -> Result<Box<dyn Model>> {
pub fn load<M: llm::KnownModel + 'static>(
&self,
overrides: Option<M::Overrides>,
) -> Result<Box<dyn Model>> {
let params = ModelParameters {
prefer_mmap: !self.no_mmap,
n_context_tokens: self.num_ctx_tokens,
lora_adapter: self.lora_path.clone(),
..Default::default()
};

@@ -333,49 +356,60 @@ impl ModelLoad {
let now = std::time::Instant::now();
let mut prev_load_time = now;

let model = llm::load::<M>(&self.model_path, params, move |progress| match progress {
LoadProgress::HyperparametersLoaded => {
if let Some(sp) = sp.as_mut() {
sp.update_text("Loaded hyperparameters")
};
}
LoadProgress::ContextSize { bytes } => log::debug!(
"ggml ctx size = {}",
bytesize::to_string(bytes as u64, false)
),
LoadProgress::TensorLoaded {
current_tensor,
tensor_count,
..
} => {
if prev_load_time.elapsed().as_millis() > 500 {
// We don't want to re-render this on every message, as that causes the
// spinner to constantly reset and not look like it's spinning (and
// it's obviously wasteful).
if let Some(sp) = sp.as_mut() {
sp.update_text(format!(
"Loaded tensor {}/{}",
current_tensor + 1,
tensor_count
));
};
prev_load_time = std::time::Instant::now();
}
}
LoadProgress::Loaded {
file_size,
tensor_count,
} => {
if let Some(sp) = sp.take() {
sp.success(&format!(
"Loaded {tensor_count} tensors ({}) after {}ms",
bytesize::to_string(file_size, false),
now.elapsed().as_millis()
));
};
}
})
.wrap_err("Could not load model")?;
let model =
llm::load::<M>(
&self.model_path,
params,
overrides,
move |progress| match progress {
LoadProgress::HyperparametersLoaded => {
if let Some(sp) = sp.as_mut() {
sp.update_text("Loaded hyperparameters")
};
}
LoadProgress::ContextSize { bytes } => log::debug!(
"ggml ctx size = {}",
bytesize::to_string(bytes as u64, false)
),
LoadProgress::LoraApplied { name } => {
if let Some(sp) = sp.as_mut() {
sp.update_text(format!("Patched tensor {} via LoRA", name));
}
}
LoadProgress::TensorLoaded {
current_tensor,
tensor_count,
..
} => {
if prev_load_time.elapsed().as_millis() > 500 {
// We don't want to re-render this on every message, as that causes the
// spinner to constantly reset and not look like it's spinning (and
// it's obviously wasteful).
if let Some(sp) = sp.as_mut() {
sp.update_text(format!(
"Loaded tensor {}/{}",
current_tensor + 1,
tensor_count
));
};
prev_load_time = std::time::Instant::now();
}
}
LoadProgress::Loaded {
file_size,
tensor_count,
} => {
if let Some(sp) = sp.take() {
sp.success(&format!(
"Loaded {tensor_count} tensors ({}) after {}ms",
bytesize::to_string(file_size, false),
now.elapsed().as_millis()
));
};
}
},
)
.wrap_err("Could not load model")?;

Ok(Box::new(model))
}
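Taken together, the changes above mean that `ModelLoad::load` now threads an optional `M::Overrides` value and a LoRA adapter path through to `llm::load`. The following is a minimal sketch of that call from a downstream caller's point of view, written against only the names visible in this diff (`llm::load`, `ModelParameters`, `llm::models::NeoX`, `llm::models::NeoXOverrides`, and `LoadProgress::LoraApplied`); exact signatures and re-export paths may differ in the published crate, and the model path is a placeholder.

```rust
// Minimal sketch of loading a RedPajama (GPT-NeoX) model with the `llm` crate
// from this repository, using only names that appear in this diff. Signatures
// are approximate and the model path is a placeholder.
use std::path::Path;

fn load_redpajama() -> Result<llm::models::NeoX, Box<dyn std::error::Error>> {
    let params = llm::ModelParameters {
        prefer_mmap: true,
        n_context_tokens: 2048,
        // Set to Some(path) to patch the model with a LoRA adapter at load time.
        lora_adapter: None,
        ..Default::default()
    };

    let model = llm::load::<llm::models::NeoX>(
        Path::new("/path/to/redpajama-incite-7b.bin"),
        params,
        // RedPajama is GPT-NeoX without the parallel residual; the GGML file
        // does not store this, so it is supplied as an override.
        Some(llm::models::NeoXOverrides {
            use_parallel_residual: false,
        }),
        |progress| {
            // Progress callback; the new LoraApplied variant fires once per
            // tensor patched by a LoRA adapter.
            if let llm::LoadProgress::LoraApplied { name } = progress {
                println!("patched tensor {name} via LoRA");
            }
        },
    )?;

    Ok(model)
}
```

Supplying the residual setting as a load-time override, rather than reading it from the file, follows from the point made in the doc comment above: the GGML model format does not record whether the network was trained with a parallel residual.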
49 changes: 35 additions & 14 deletions binaries/llm-cli/src/main.rs
@@ -25,30 +25,50 @@ fn main() -> Result<()> {

let cli_args = Args::parse();
match &cli_args {
Args::Llama { args } => handle_args::<llm::models::Llama>(args),
Args::Bloom { args } => handle_args::<llm::models::Bloom>(args),
Args::Gpt2 { args } => handle_args::<llm::models::Gpt2>(args),
Args::GptJ { args } => handle_args::<llm::models::GptJ>(args),
Args::NeoX { args } => handle_args::<llm::models::NeoX>(args),
Args::Mpt { args } => handle_args::<llm::models::Mpt>(args),
Args::Llama { args } => handle_args::<llm::models::Llama>(args, None),
Args::Bloom { args } => handle_args::<llm::models::Bloom>(args, None),
Args::Gpt2 { args } => handle_args::<llm::models::Gpt2>(args, None),
Args::GptJ { args } => handle_args::<llm::models::GptJ>(args, None),
Args::NeoX {
args,
no_parallel_residual,
} => handle_args::<llm::models::NeoX>(
args,
Some(llm::models::NeoXOverrides {
use_parallel_residual: !*no_parallel_residual,
}),
),
Args::RedPajama { args } => handle_args::<llm::models::NeoX>(
args,
Some(llm::models::NeoXOverrides {
use_parallel_residual: false,
}),
),
Args::Mpt { args } => handle_args::<llm::models::Mpt>(args, None),
}
}

fn handle_args<M: llm::KnownModel + 'static>(args: &cli_args::BaseArgs) -> Result<()> {
fn handle_args<M: llm::KnownModel + 'static>(
args: &cli_args::BaseArgs,
overrides: Option<M::Overrides>,
) -> Result<()> {
match args {
BaseArgs::Infer(args) => infer::<M>(args),
BaseArgs::Infer(args) => infer::<M>(args, overrides),
BaseArgs::Info(args) => info::<M>(args),
BaseArgs::PromptTokens(args) => prompt_tokens::<M>(args),
BaseArgs::Repl(args) => interactive::<M>(args, false),
BaseArgs::Chat(args) => interactive::<M>(args, true),
BaseArgs::Repl(args) => interactive::<M>(args, overrides, false),
BaseArgs::Chat(args) => interactive::<M>(args, overrides, true),
BaseArgs::Quantize(args) => quantize::<M>(args),
}
}

fn infer<M: llm::KnownModel + 'static>(args: &cli_args::Infer) -> Result<()> {
fn infer<M: llm::KnownModel + 'static>(
args: &cli_args::Infer,
overrides: Option<M::Overrides>,
) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let inference_session_config = args.generate.inference_session_config();
let model = args.model_load.load::<M>()?;
let model = args.model_load.load::<M>(overrides)?;
let (mut session, session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
args.persist_session.as_deref(),
@@ -145,7 +165,7 @@ fn info<M: llm::KnownModel + 'static>(args: &cli_args::Info) -> Result<()> {

fn prompt_tokens<M: llm::KnownModel + 'static>(args: &cli_args::PromptTokens) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let model = args.model_load.load::<M>()?;
let model = args.model_load.load::<M>(None)?;
let toks = match model.vocabulary().tokenize(&prompt, false) {
Ok(toks) => toks,
Err(e) => {
@@ -174,13 +194,14 @@ fn prompt_tokens<M: llm::KnownModel + 'static>(args: &cli_args::PromptTokens) ->

fn interactive<M: llm::KnownModel + 'static>(
args: &cli_args::Repl,
overrides: Option<M::Overrides>,
// If set to false, the session will be cloned after each inference
// to ensure that previous state is not carried over.
chat_mode: bool,
) -> Result<()> {
let prompt_file = args.prompt_file.contents();
let inference_session_config = args.generate.inference_session_config();
let model = args.model_load.load::<M>()?;
let model = args.model_load.load::<M>(overrides)?;
let (mut session, session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
None,
