Skip to content

Commit

Permalink
Merge pull request rustformers#362 from steventrouble/wte2
Browse files Browse the repository at this point in the history
Update gpt2 to use wte if no lm_head
  • Loading branch information
philpax authored Jul 11, 2023
2 parents 7f13bb9 + c995ca8 commit cf6086c
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 6 deletions.
5 changes: 5 additions & 0 deletions crates/ggml/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,11 @@ impl Context {

/// Creates a 1D view over `a`.
pub fn op_view_1d(&self, a: &Tensor, ne0: usize, offset: usize) -> Tensor {
#[cfg(debug_assertions)]
assert!(
offset < a.nbytes(),
"Cannot create tensor view with offset larger than tensor"
);
let tensor = unsafe {
sys::ggml_view_1d(self.ptr.as_ptr(), a.ptr.as_ptr(), usize_to_i64(ne0), offset)
};
Expand Down
14 changes: 10 additions & 4 deletions crates/models/gpt2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ pub struct Gpt2 {
// weighted positional encodings
wpe: Tensor,
// language model head
lm_head: Tensor,
//
// Optional: if not present, the `wte` tensor is used instead.
lm_head: Option<Tensor>,

// weights for the model
layers: Vec<Layer>,
Expand Down Expand Up @@ -59,7 +61,10 @@ impl KnownModel for Gpt2 {
let ln_f_b = tl.load("model/ln_f/b")?;
let wte = tl.load("model/wte")?;
let wpe = tl.load("model/wpe")?;
let lm_head = tl.load("model/lm_head")?;

// GPT-2's language model head is optional; if it is not present,
// the `wte` tensor is used instead.
let lm_head = tl.load("model/lm_head").ok();

let mut layers = Vec::new();
for i in 0..hyperparameters.n_layer {
Expand Down Expand Up @@ -102,7 +107,7 @@ impl KnownModel for Gpt2 {
fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
InferenceSession::new(
config,
self.hyperparameters.n_ctx,
self.context_size,
self.hyperparameters.n_layer,
self.hyperparameters.n_embd,
self.hyperparameters.n_vocab,
Expand Down Expand Up @@ -306,7 +311,8 @@ impl KnownModel for Gpt2 {

let embeddings_tensor: ggml::Tensor = input_layer.share();

input_layer = ctx0.op_mul_mat(&self.lm_head, &input_layer);
let head = self.lm_head.as_ref().unwrap_or(&self.wte);
input_layer = ctx0.op_mul_mat(head, &input_layer);

(
gf,
Expand Down
2 changes: 1 addition & 1 deletion crates/models/gptj/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl KnownModel for GptJ {
fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
InferenceSession::new(
config,
self.hyperparameters.n_ctx,
self.context_size,
self.hyperparameters.n_layer,
self.hyperparameters.n_embd,
self.hyperparameters.n_vocab,
Expand Down
2 changes: 1 addition & 1 deletion crates/models/gptneox/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl KnownModel for GptNeoX {
fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
InferenceSession::new(
config,
self.hyperparameters.n_ctx,
self.context_size,
self.hyperparameters.n_layer,
self.hyperparameters.n_embd,
self.hyperparameters.n_vocab,
Expand Down

0 comments on commit cf6086c

Please sign in to comment.