Merge branch 'main' into dfo/model/mpt
danforbes committed May 16, 2023
2 parents e797746 + 7468ee9 commit 2b03ccc
Showing 33 changed files with 1,611 additions and 328 deletions.
20 changes: 10 additions & 10 deletions Cargo.lock


8 changes: 3 additions & 5 deletions binaries/generate-ggml-bindings/src/main.rs
@@ -5,17 +5,15 @@
use std::path::PathBuf;

fn main() {
const HEADER_PATH: &str = "crates/ggml/sys/ggml/include/ggml/ggml.h";

let bindings = bindgen::Builder::default()
.header(HEADER_PATH)
.header("crates/ggml/sys/bindings.h")
// Suppress some warnings
.raw_line("#![allow(non_upper_case_globals)]")
.raw_line("#![allow(non_camel_case_types)]")
.raw_line("#![allow(non_snake_case)]")
.raw_line("#![allow(unused)]")
// Do not generate code for ggml's includes (stdlib)
.allowlist_file(HEADER_PATH)
// Only generate code if it's from GGML
.allowlist_file("crates/ggml/.*")
.generate()
.expect("Unable to generate bindings");

10 changes: 5 additions & 5 deletions binaries/llm-cli/src/cli_args.rs
@@ -471,13 +471,13 @@ pub enum FileType {
/// Float 32-bit.
F32,
}
impl From<FileType> for llm::FileType {
impl From<FileType> for llm::FileTypeFormat {
fn from(t: FileType) -> Self {
match t {
FileType::Q4_0 => llm::FileType::MostlyQ4_0,
FileType::Q4_1 => llm::FileType::MostlyQ4_1,
FileType::F16 => llm::FileType::MostlyF16,
FileType::F32 => llm::FileType::F32,
FileType::Q4_0 => llm::FileTypeFormat::MostlyQ4_0,
FileType::Q4_1 => llm::FileTypeFormat::MostlyQ4_1,
FileType::F16 => llm::FileTypeFormat::MostlyF16,
FileType::F32 => llm::FileTypeFormat::F32,
}
}
}
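
A small usage sketch of the renamed conversion (the CLI-level `FileType` enum and `llm::FileTypeFormat` are taken from the change above; the call site itself is hypothetical):

// Convert the CLI-level quantization flag into the library's format enum.
let format: llm::FileTypeFormat = FileType::Q4_0.into();
assert!(matches!(format, llm::FileTypeFormat::MostlyQ4_0));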
8 changes: 4 additions & 4 deletions binaries/llm-cli/src/main.rs
@@ -32,15 +32,15 @@ fn main() -> Result<()> {
Args::NeoX {
args,
no_parallel_residual,
} => handle_args::<llm::models::NeoX>(
} => handle_args::<llm::models::GptNeoX>(
args,
Some(llm::models::NeoXOverrides {
Some(llm::models::GptNeoXOverrides {
use_parallel_residual: !*no_parallel_residual,
}),
),
Args::RedPajama { args } => handle_args::<llm::models::NeoX>(
Args::RedPajama { args } => handle_args::<llm::models::GptNeoX>(
args,
Some(llm::models::NeoXOverrides {
Some(llm::models::GptNeoXOverrides {
use_parallel_residual: false,
}),
),
45 changes: 40 additions & 5 deletions crates/ggml/src/context.rs
@@ -143,26 +143,47 @@ impl Context {
self.new_tensor_raw(tensor)
}

/// In-place, scales `a` by the 1D tensor `b`.
/// Scales `a` by the 1D tensor `b`.
pub fn op_scale(&self, a: &Tensor, b: &Tensor) -> Tensor {
let tensor = unsafe { sys::ggml_scale(self.ptr.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
}

/// In-place, sets the elements above the diagonal to -INF.
/// In-place, scales `a` by the 1D tensor `b`.
pub fn op_scale_inplace(&self, a: &Tensor, b: &Tensor) -> Tensor {
let tensor =
unsafe { sys::ggml_scale_inplace(self.ptr.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
}

/// Sets the elements above the diagonal to -INF.
pub fn op_diag_mask_inf(&self, a: &Tensor, n_past: usize) -> Tensor {
let tensor = unsafe {
sys::ggml_diag_mask_inf(self.ptr.as_ptr(), a.ptr.as_ptr(), usize_to_i32(n_past))
};
self.new_tensor_raw(tensor)
}

/// In-place, applies the [Softmax function](https://en.wikipedia.org/wiki/Softmax_function) to `a`.
/// In-place, sets the elements above the diagonal to -INF.
pub fn op_diag_mask_inf_inplace(&self, a: &Tensor, n_past: usize) -> Tensor {
let tensor = unsafe {
sys::ggml_diag_mask_inf_inplace(self.ptr.as_ptr(), a.ptr.as_ptr(), usize_to_i32(n_past))
};
self.new_tensor_raw(tensor)
}

/// Applies the [Softmax function](https://en.wikipedia.org/wiki/Softmax_function) to `a`.
pub fn op_soft_max(&self, a: &Tensor) -> Tensor {
let tensor = unsafe { sys::ggml_soft_max(self.ptr.as_ptr(), a.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
}

/// In-place, applies the [Softmax function](https://en.wikipedia.org/wiki/Softmax_function) to `a`.
pub fn op_soft_max_inplace(&self, a: &Tensor) -> Tensor {
let tensor = unsafe { sys::ggml_soft_max_inplace(self.ptr.as_ptr(), a.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
}

/// Creates a new tensor with result of mapping `fun` with `a`.
///
/// `cnt` is the number of `f32` elements to be mapped.
@@ -332,7 +353,7 @@ impl Context {
self.new_tensor_raw(tensor)
}

/// In-place; applies ROtary Positional Encoding.
/// Applies ROtary Positional Encoding.
pub fn op_rope(&self, a: &Tensor, npast: usize, ndims: usize, mode: i32) -> Tensor {
let tensor = unsafe {
sys::ggml_rope(
@@ -346,6 +367,20 @@ impl Context {
self.new_tensor_raw(tensor)
}

/// In-place; applies ROtary Positional Encoding.
pub fn op_rope_inplace(&self, a: &Tensor, npast: usize, ndims: usize, mode: i32) -> Tensor {
let tensor = unsafe {
sys::ggml_rope_inplace(
self.ptr.as_ptr(),
a.ptr.as_ptr(),
usize_to_i32(npast),
usize_to_i32(ndims),
mode,
)
};
self.new_tensor_raw(tensor)
}

/// Computes the specified graph. Must be run in order to evaluate the graph.
pub fn graph_compute(&self, graph: &mut ComputationGraph) {
unsafe {
@@ -380,7 +415,7 @@ impl Context {
}
}

/// TODO: something something
/// Attention with LInear BIases (Ref: <https://arxiv.org/pdf/2108.12409.pdf>)
pub fn op_alibi(&self, a: &Tensor, n_past: usize, n_head: usize) -> Tensor {
let tensor = unsafe {
sys::ggml_alibi(
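
The new `*_inplace` variants mirror upstream GGML's split between allocating and in-place ops (`ggml_scale` vs `ggml_scale_inplace`, and so on). A sketch, not part of the diff, of how they compose in a typical masked-attention step; `ctx`, `kq`, and `scale` are assumed to be a `ggml::Context` and `ggml::Tensor`s built elsewhere:

fn masked_softmax(ctx: &ggml::Context, kq: &ggml::Tensor, scale: &ggml::Tensor, n_past: usize) -> ggml::Tensor {
    // op_scale allocates a fresh result tensor...
    let scaled = ctx.op_scale(kq, scale);
    // ...while the *_inplace variants reuse their input's buffer, saving context memory.
    let masked = ctx.op_diag_mask_inf_inplace(&scaled, n_past);
    ctx.op_soft_max_inplace(&masked)
}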
42 changes: 18 additions & 24 deletions crates/ggml/src/format/loader.rs
@@ -20,9 +20,9 @@ pub enum LoadError<E: Error> {
#[error("invalid file magic number: {0}")]
/// The file magic number is invalid.
InvalidMagic(u32),
#[error("invalid ggml format: format={0:?} version={1}")]
#[error("invalid ggml format: format={0:?}")]
/// An unsupported format version was found.
InvalidFormatVersion(ContainerType, u32),
InvalidFormatVersion(ContainerType),
#[error("non-specific I/O error")]
/// A non-specific IO error.
Io(#[from] std::io::Error),
@@ -142,28 +142,20 @@ pub fn load<E: Error, R: BufRead + Seek>(
handler: &mut impl LoadHandler<E>,
) -> Result<(), LoadError<E>> {
// Verify magic
let container_type: ContainerType = match read_u32(reader)? {
crate::FILE_MAGIC_GGMF => ContainerType::Ggmf,
crate::FILE_MAGIC_GGJT => ContainerType::Ggjt,
crate::FILE_MAGIC_UNVERSIONED => ContainerType::Ggml,
crate::FILE_MAGIC_GGLA => ContainerType::Ggla,
magic => return Err(LoadError::InvalidMagic(magic)),
};
handler
.container_type(container_type)
.map_err(LoadError::ImplementationError)?;
let container_type = ContainerType::read(reader)?;

// Load format version
match container_type {
ContainerType::Ggmf | ContainerType::Ggjt | ContainerType::Ggla => {
let _version: u32 = match read_u32(reader)? {
crate::FORMAT_VERSION => crate::FORMAT_VERSION,
version => return Err(LoadError::InvalidFormatVersion(container_type, version)),
};
}
ContainerType::Ggml => {}
ContainerType::Ggml
| ContainerType::Ggmf(1)
| ContainerType::Ggjt(1 | 2)
| ContainerType::Ggla(1) => {}
_ => return Err(LoadError::InvalidFormatVersion(container_type)),
}

handler
.container_type(container_type)
.map_err(LoadError::ImplementationError)?;

// Load hyper params
let hparams = handler
.read_hyperparameters(reader)
@@ -175,8 +167,8 @@ pub fn load<E: Error, R: BufRead + Seek>(
let len = read_u32(reader)?.try_into()?;
let token = read_bytes_with_len(reader, len)?;
let token_score = match container_type {
ContainerType::Ggmf | ContainerType::Ggjt => read_f32(reader)?,
ContainerType::Ggml | ContainerType::Ggla => {
ContainerType::Ggmf(_version) | ContainerType::Ggjt(_version) => read_f32(reader)?,
ContainerType::Ggml | ContainerType::Ggla(_) => {
// Legacy model, set empty score
0.
}
@@ -188,8 +180,10 @@ pub fn load<E: Error, R: BufRead + Seek>(

// Load tensor data
match container_type {
ContainerType::Ggmf | ContainerType::Ggml => load_weights(reader, handler, false),
ContainerType::Ggjt | ContainerType::Ggla => load_weights(reader, handler, true),
ContainerType::Ggmf(_) | ContainerType::Ggml => load_weights(reader, handler, false),
ContainerType::Ggjt(_version) | ContainerType::Ggla(_version) => {
load_weights(reader, handler, true)
}
}
}
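
For reference, the same version gate written as a standalone check (a sketch assuming the versioned `ContainerType` variants introduced by this change carry their format version as a `u32`):

/// Returns true for the container/version combinations the loader accepts:
/// unversioned GGML, GGMF v1, GGLA v1, and GGJT v1 or v2.
fn is_supported(container_type: &ggml::ContainerType) -> bool {
    use ggml::ContainerType::{Ggjt, Ggla, Ggmf, Ggml};
    matches!(container_type, Ggml | Ggmf(1) | Ggjt(1 | 2) | Ggla(1))
}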

7 changes: 3 additions & 4 deletions crates/ggml/src/format/saver.rs
@@ -9,7 +9,7 @@ use std::{
io::{Seek, Write},
};

use crate::{util, ElementType};
use crate::{util, ContainerType, ElementType};

#[derive(Debug, thiserror::Error)]
/// Errors that can occur while writing a model.
@@ -57,16 +57,15 @@ pub struct TensorSaveInfo {

/// Saves a model to the given writer.
///
/// Only GGJT is supported.
/// Only GGJT version 2 is supported.
pub fn save<E: Error, W: Write + Seek>(
writer: &mut W,
handler: &mut dyn SaveHandler<E>,
vocabulary: &[(Vec<u8>, f32)],
tensor_names: &[String],
) -> Result<(), SaveError<E>> {
// Write header and hyperparameters
util::write_u32(writer, crate::FILE_MAGIC_GGJT)?;
util::write_u32(writer, crate::FORMAT_VERSION)?;
ContainerType::Ggjt(2).write(writer)?;
handler
.write_hyperparameters(writer)
.map_err(SaveError::ImplementationError)?;
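
A hypothetical snippet showing the new header write on its own (assumes `ContainerType::write` accepts any `std::io::Write` implementor and returns an `io::Result`, per its use with `?` in the diff):

// Write the GGJT magic and version 2 into an in-memory buffer.
let mut buf = std::io::Cursor::new(Vec::new());
ggml::ContainerType::Ggjt(2)
    .write(&mut buf)
    .expect("failed to write GGJT v2 header");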
