diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index b10e9b25..db78b3a0 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -18,7 +18,7 @@ pub struct Llama { params: ModelParameters, hyperparameters: Hyperparameters, tokenizer: Tokenizer, - _version: LlamaModelVersion, + _version: LlamaModelType, // model-global weights // weighted token embeddings wte: ggml::Tensor, @@ -96,12 +96,12 @@ impl KnownModel for Llama { // TODO: read from file let mut version = match hyperparameters.n_layer { - 26 => LlamaModelVersion::Model3b, - 32 => LlamaModelVersion::Model7b, - 40 => LlamaModelVersion::Model13b, - 60 => LlamaModelVersion::Model30b, - 80 => LlamaModelVersion::Model65b, - _ => LlamaModelVersion::Model7b, // anything < 32 + 26 => LlamaModelType::Model3b, + 32 => LlamaModelType::Model7b, + 40 => LlamaModelType::Model13b, + 60 => LlamaModelType::Model30b, + 80 => LlamaModelType::Model65b, + _ => LlamaModelType::Model7b, // fallback: any other layer count is treated as 7B }; // TODO: temporary fix for 70B models if let Some(n_gqa) = params.n_gqa { @@ -112,7 +112,7 @@ impl KnownModel for Llama { "assuming 70B Llama2 model based on GQA == 8" ); hyperparameters.n_head_kv = hyperparameters.n_head / n_gqa; - version = LlamaModelVersion::Model70b; + version = LlamaModelType::Model70b; } } @@ -488,7 +488,7 @@ struct Layer { } /// Available Llama models -enum LlamaModelVersion { +enum LlamaModelType { Model3b, Model7b, Model13b,