Skip to content

Commit

Permalink
Add the new gemma models. (huggingface#2023)
Browse files Browse the repository at this point in the history
* Add the new gemma models.

* Revert the lightning changes.

* Support for the 1.1 models.
  • Loading branch information
LaurentMazare authored Apr 6, 2024
1 parent 9fd52b3 commit 33c9b66
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
35 changes: 28 additions & 7 deletions candle-examples/examples/gemma/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "2b")]
Base2B,
#[value(name = "7b")]
Base7B,
#[value(name = "2b-it")]
Instruct2B,
#[value(name = "7b-it")]
Instruct7B,
#[value(name = "1.1-2b-it")]
InstructV1_1_2B,
#[value(name = "1.1-7b-it")]
InstructV1_1_7B,
}

struct TextGeneration {
model: Model,
device: Device,
Expand Down Expand Up @@ -165,6 +181,10 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,

/// The model to use.
#[arg(long, default_value = "2b")]
which: Which,
}

fn main() -> Result<()> {
Expand Down Expand Up @@ -196,14 +216,15 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => match model_id.as_str() {
"7b-it" => "google/gemma-7b-it".to_string(),
"7b" => "google/gemma-7b".to_string(),
"2b-it" => "google/gemma-2b-it".to_string(),
"2b" => "google/gemma-2b".to_string(),
_ => model_id.to_string(),
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
Which::Base2B => "google/gemma-2b".to_string(),
Which::Base7B => "google/gemma-7b".to_string(),
Which::Instruct2B => "google/gemma-2b-it".to_string(),
Which::Instruct7B => "google/gemma-7b-it".to_string(),
},
None => "google/gemma-2b".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
Expand Down
1 change: 1 addition & 0 deletions candle-transformers/src/models/gemma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ fn default_max_position_embeddings() -> usize {
pub struct Config {
pub attention_bias: bool,
pub head_dim: usize,
#[serde(alias = "hidden_activation")]
pub hidden_act: candle_nn::Activation,
pub hidden_size: usize,
pub intermediate_size: usize,
Expand Down

0 comments on commit 33c9b66

Please sign in to comment.