NeoX -> GPT-NeoX
danforbes committed May 16, 2023
1 parent 25942f9 commit 1320e54
Showing 7 changed files with 52 additions and 52 deletions.
16 changes: 8 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default.

8 changes: 4 additions & 4 deletions binaries/llm-cli/src/main.rs
@@ -32,15 +32,15 @@ fn main() -> Result<()> {
Args::NeoX {
args,
no_parallel_residual,
-} => handle_args::<llm::models::NeoX>(
+} => handle_args::<llm::models::GptNeoX>(
args,
-Some(llm::models::NeoXOverrides {
+Some(llm::models::GptNeoXOverrides {
use_parallel_residual: !*no_parallel_residual,
}),
),
-Args::RedPajama { args } => handle_args::<llm::models::NeoX>(
+Args::RedPajama { args } => handle_args::<llm::models::GptNeoX>(
args,
-Some(llm::models::NeoXOverrides {
+Some(llm::models::GptNeoXOverrides {
use_parallel_residual: false,
}),
),
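
For downstream code, the rename is mechanical: the re-exports previously named `llm::models::NeoX` and `llm::models::NeoXOverrides` become `GptNeoX` and `GptNeoXOverrides`. A minimal hedged sketch mirroring the CLI arm above (the `overrides_from_cli` helper is hypothetical; it assumes the `llm` crate with the `gptneox` feature as a dependency):

use llm::models::GptNeoXOverrides;

// Hypothetical helper mirroring the Args::NeoX arm above; before this commit
// the same struct was re-exported as llm::models::NeoXOverrides.
fn overrides_from_cli(no_parallel_residual: bool) -> GptNeoXOverrides {
    GptNeoXOverrides {
        // The CLI flag disables GPT-NeoX's parallel residual formulation.
        use_parallel_residual: !no_parallel_residual,
    }
}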
6 changes: 3 additions & 3 deletions crates/llm/Cargo.toml
@@ -13,7 +13,7 @@ llm-llama = { path = "../models/llama", features = ["convert"], optional = true,
llm-gpt2 = { path = "../models/gpt2", optional = true, version = "0.1.1" }
llm-gptj = { path = "../models/gptj", optional = true, version = "0.1.1" }
llm-bloom = { path = "../models/bloom", optional = true, version = "0.1.1" }
-llm-neox = { path = "../models/neox", optional = true, version = "0.1.1" }
+llm-gptneox = { path = "../models/gptneox", optional = true, version = "0.1.1" }

serde = { workspace = true }

@@ -26,9 +26,9 @@ spinoff = { workspace = true }
serde_json = { workspace = true }

[features]
-default = ["llama", "gpt2", "gptj", "bloom", "neox"]
+default = ["llama", "gpt2", "gptj", "bloom", "gptneox"]
llama = ["dep:llm-llama"]
gpt2 = ["dep:llm-gpt2"]
gptj = ["dep:llm-gptj"]
bloom = ["dep:llm-bloom"]
-neox = ["dep:llm-neox"]
+gptneox = ["dep:llm-gptneox"]
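
A crate that only needs GPT-NeoX support would now depend on `llm` with something like `default-features = false, features = ["gptneox"]` in place of the old `neox` feature; the exact dependency line is an assumption, but with the feature enabled the renamed re-exports are available:

use llm::models::GptNeoXOverrides;

// Hedged sketch, assuming the dependency line above. `use_parallel_residual`
// defaults to true (see the Default impl further down in this commit);
// RedPajama-style checkpoints set it to false.
fn parallel_residual_default() -> bool {
    GptNeoXOverrides::default().use_parallel_residual
}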
44 changes: 22 additions & 22 deletions crates/llm/src/lib.rs
@@ -5,7 +5,7 @@
//! - [GPT-2](llm_gpt2)
//! - [GPT-J](llm_gptj)
//! - [LLaMA](llm_llama)
-//! - [GPT-NeoX](llm_neox)
+//! - [GPT-NeoX](llm_gptneox)
//!
//! At present, the only supported backend is [GGML](https://github.com/ggerganov/ggml), but this is expected to
//! change in the future.
@@ -92,10 +92,10 @@ pub mod models {
pub use llm_gpt2::{self as gpt2, Gpt2};
#[cfg(feature = "gptj")]
pub use llm_gptj::{self as gptj, GptJ};
#[cfg(feature = "gptneox")]
pub use llm_gptneox::{self as gptneox, GptNeoX, GptNeoXOverrides};
#[cfg(feature = "llama")]
pub use llm_llama::{self as llama, Llama};
#[cfg(feature = "neox")]
pub use llm_neox::{self as neox, NeoX, NeoXOverrides};
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
@@ -113,11 +113,11 @@ pub enum ModelArchitecture {
#[cfg(feature = "llama")]
/// [LLaMA](llm_llama)
Llama,
#[cfg(feature = "neox")]
/// [GPT-NeoX](llm_neox)
NeoX,
#[cfg(feature = "neox")]
/// RedPajama: [GPT-NeoX](llm_neox) with `use_parallel_residual` set to false
#[cfg(feature = "gptneox")]
/// [GPT-NeoX](llm_gptneox)
GptNeoX,
#[cfg(feature = "gptneox")]
/// RedPajama: [GPT-NeoX](llm_gptneox) with `use_parallel_residual` set to false
RedPajama,
}

@@ -132,9 +132,9 @@ impl ModelArchitecture {
Self::GptJ,
#[cfg(feature = "llama")]
Self::Llama,
#[cfg(feature = "neox")]
Self::NeoX,
#[cfg(feature = "neox")]
#[cfg(feature = "gptneox")]
Self::GptNeoX,
#[cfg(feature = "gptneox")]
Self::RedPajama,
];
}
@@ -175,9 +175,9 @@ impl FromStr for ModelArchitecture {
"gptj" => Ok(GptJ),
#[cfg(feature = "llama")]
"llama" => Ok(Llama),
#[cfg(feature = "neox")]
"gptneox" => Ok(NeoX),
#[cfg(feature = "neox")]
#[cfg(feature = "gptneox")]
"gptneox" => Ok(GptNeoX),
#[cfg(feature = "gptneox")]
"redpajama" => Ok(RedPajama),
m => Err(UnsupportedModelArchitecture(format!(
"{m} is not a supported model architecture"
@@ -199,9 +199,9 @@ impl Display for ModelArchitecture {
GptJ => write!(f, "GPT-J"),
#[cfg(feature = "llama")]
Llama => write!(f, "LLaMA"),
#[cfg(feature = "neox")]
NeoX => write!(f, "GPT-NeoX"),
#[cfg(feature = "neox")]
#[cfg(feature = "gptneox")]
GptNeoX => write!(f, "GPT-NeoX"),
#[cfg(feature = "gptneox")]
RedPajama => write!(f, "RedPajama"),
}
}
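
Note that the strings accepted by `FromStr` and produced by `Display` are unchanged by this commit; only the variant behind "gptneox" is renamed. A short hedged usage sketch, assuming `ModelArchitecture` remains reachable at the crate root as defined in this file:

use std::str::FromStr;
use llm::ModelArchitecture;

// Hedged sketch of the parsing and formatting shown above.
fn demo_architecture_names() {
    if let Ok(arch) = ModelArchitecture::from_str("gptneox") {
        // The "gptneox" string now maps to the GptNeoX variant.
        assert_eq!(arch.to_string(), "GPT-NeoX");
    }
    if let Ok(arch) = ModelArchitecture::from_str("redpajama") {
        assert_eq!(arch.to_string(), "RedPajama");
    }
}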
@@ -247,15 +247,15 @@ pub fn load_dynamic(
GptJ => load_model::<models::GptJ>(path, params, overrides, load_progress_callback)?,
#[cfg(feature = "llama")]
Llama => load_model::<models::Llama>(path, params, overrides, load_progress_callback)?,
#[cfg(feature = "neox")]
NeoX => load_model::<models::NeoX>(path, params, overrides, load_progress_callback)?,
#[cfg(feature = "neox")]
RedPajama => load_model::<models::NeoX>(
#[cfg(feature = "gptneox")]
GptNeoX => load_model::<models::GptNeoX>(path, params, overrides, load_progress_callback)?,
#[cfg(feature = "gptneox")]
RedPajama => load_model::<models::GptNeoX>(
path,
params,
{
let mut overrides = overrides.unwrap_or_default();
-overrides.merge(models::NeoXOverrides {
+overrides.merge(models::GptNeoXOverrides {
use_parallel_residual: false,
});
Some(overrides)
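
As the arm above shows, `RedPajama` is not a separate implementation: it loads `GptNeoX` and merges overrides that force `use_parallel_residual` to false. A hedged sketch of just that overrides value, kept generic so it does not assume the exact re-export path of the dynamic-overrides type (the `redpajama_overrides` helper is hypothetical):

use llm::models::GptNeoXOverrides;

// Hypothetical helper: builds the overrides the RedPajama arm merges in.
// Generic over the target type so it does not hard-code where the dynamic
// overrides type is re-exported from; the From<GptNeoXOverrides> conversion
// appears later in this commit.
fn redpajama_overrides<T: From<GptNeoXOverrides>>() -> T {
    GptNeoXOverrides {
        use_parallel_residual: false,
    }
    .into()
}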
File renamed without changes.
2 changes: 1 addition & 1 deletion crates/models/neox/Cargo.toml → crates/models/gptneox/Cargo.toml
@@ -1,5 +1,5 @@
[package]
name = "llm-neox"
name = "llm-gptneox"
version = "0.1.1"
license = { workspace = true }
repository = { workspace = true }
28 changes: 14 additions & 14 deletions crates/models/neox/src/lib.rs → crates/models/gptneox/src/lib.rs
@@ -18,7 +18,7 @@ use serde::{Deserialize, Serialize};
///
/// # Safety
/// This implements [Send] and [Sync] as it is immutable after construction.
-pub struct NeoX {
+pub struct GptNeoX {
hyperparameters: Hyperparameters,
n_context_tokens: usize,

@@ -45,37 +45,37 @@ pub struct NeoX {
_context: ggml::Context,
}

-unsafe impl Send for NeoX {}
-unsafe impl Sync for NeoX {}
+unsafe impl Send for GptNeoX {}
+unsafe impl Sync for GptNeoX {}

#[derive(Serialize, Deserialize, Clone, Copy)]
/// Overrides for the GPT-NeoX model.
-pub struct NeoXOverrides {
+pub struct GptNeoXOverrides {
/// Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
/// speedup at large scales (e.g. 20B).
///
/// Defaults to `true`.
/// The RedPajama models use `false`.
pub use_parallel_residual: bool,
}
-impl Default for NeoXOverrides {
+impl Default for GptNeoXOverrides {
fn default() -> Self {
Self {
use_parallel_residual: true,
}
}
}
-impl From<ModelDynamicOverrides> for NeoXOverrides {
+impl From<ModelDynamicOverrides> for GptNeoXOverrides {
fn from(val: ModelDynamicOverrides) -> Self {
-let mut overrides = NeoXOverrides::default();
+let mut overrides = GptNeoXOverrides::default();
if let Some(v) = val.get("use_parallel_residual") {
overrides.use_parallel_residual = v;
}
overrides
}
}
-impl From<NeoXOverrides> for ModelDynamicOverrides {
-fn from(val: NeoXOverrides) -> Self {
+impl From<GptNeoXOverrides> for ModelDynamicOverrides {
+fn from(val: GptNeoXOverrides) -> Self {
let mut overrides = ModelDynamicOverrides::default();
overrides.insert(
"use_parallel_residual".to_string(),
@@ -85,9 +85,9 @@ impl From<NeoXOverrides> for ModelDynamicOverrides {
}
}

-impl KnownModel for NeoX {
+impl KnownModel for GptNeoX {
type Hyperparameters = Hyperparameters;
-type Overrides = NeoXOverrides;
+type Overrides = GptNeoXOverrides;

fn new<E: Error>(
hyperparameters: Hyperparameters,
@@ -153,7 +153,7 @@ impl KnownModel for NeoX {
hyperparameters.use_parallel_residual = overrides.use_parallel_residual;
}

-Ok(NeoX {
+Ok(GptNeoX {
hyperparameters,
n_context_tokens,
vocabulary,
@@ -548,7 +548,7 @@ struct Layer {
}

#[cfg(test)]
-impl NeoX {
+impl GptNeoX {
/// This does *not* construct a valid model. All of the tensors are entirely
/// empty. However, it can be used to determine if some code will compile.
fn new_empty() -> Self {
@@ -577,7 +577,7 @@ mod tests {

#[test]
fn can_share_model_between_threads() {
-let model = Arc::new(NeoX::new_empty());
+let model = Arc::new(GptNeoX::new_empty());

for _ in 0..4 {
let model = model.clone();

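The `use_parallel_residual` flag exposed by `GptNeoXOverrides` controls how each transformer block combines attention and feed-forward outputs with the residual stream. A purely conceptual sketch, with scalar closures standing in for the real GGML tensor operations (none of these helper names exist in the crate):

// Conceptual sketch only: `attention`, `feed_forward`, `ln1` and `ln2` are
// hypothetical stand-ins for the crate's actual GGML graph operations.
fn block_output(
    x: f32,
    use_parallel_residual: bool,
    ln1: impl Fn(f32) -> f32,
    ln2: impl Fn(f32) -> f32,
    attention: impl Fn(f32) -> f32,
    feed_forward: impl Fn(f32) -> f32,
) -> f32 {
    if use_parallel_residual {
        // GPT-NeoX default: attention and MLP both read the block input, and
        // their outputs are added to the residual stream together.
        x + attention(ln1(x)) + feed_forward(ln2(x))
    } else {
        // RedPajama-style (use_parallel_residual = false): the MLP runs on the
        // residual stream after attention has already been applied.
        let after_attn = x + attention(ln1(x));
        after_attn + feed_forward(ln2(after_attn))
    }
}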