Skip to content

Commit

Permalink
refactor(ggml): unify context creation
Browse files Browse the repository at this point in the history
  • Loading branch information
philpax committed Jul 16, 2023
1 parent 55b2dc3 commit 70d57dc
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 52 deletions.
104 changes: 60 additions & 44 deletions crates/ggml/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@ pub struct Context {
/// with it if the underlying context has been deallocated.
inner: Arc<ContextInner>,

/// Memory mapping information
pub mmap: Option<Mmap>,

/// Backing buffer (in case we own it)
pub buffer: Option<Buffer>,
/// The storage for this context. This is stored so that the buffer can be dropped when the context is dropped.
storage: Option<ContextStorage>,

/// Whether the context can offload tensors to the GPU
pub can_offload: bool,
Expand All @@ -46,7 +43,6 @@ pub(crate) struct ContextInner {
// interface and its scratch buffer solution.
pub offloaded_tensors: Mutex<HashMap<String, Tensor>>,
}

impl ContextInner {
pub(crate) fn new(ptr: *mut ggml_sys::ggml_context) -> Arc<Self> {
Arc::new(Self {
Expand All @@ -56,60 +52,72 @@ impl ContextInner {
}
}

/// Controls how the context uses memory.
pub enum ContextStorage {
/// Use the provided buffer as memory.
Buffer(Buffer),
/// Use the provided memory mapped file as memory.
Mmap(Mmap),
/// Allocate `mem_size` bytes of memory.
Allocate {
/// The size, in bytes, of the memory in to allocate.
mem_size: usize,
},
}

impl Context {
/// Creates a new [Context] using the buffer provided as memory
pub fn init_buffer(buffer: Buffer) -> Self {
let raw = unsafe {
sys::ggml_init(sys::ggml_init_params {
/// Creates a new [Context] with the given storage..
pub fn new(storage: ContextStorage) -> Self {
let init_params = match &storage {
ContextStorage::Buffer(buffer) => sys::ggml_init_params {
mem_size: buffer.size(),
mem_buffer: buffer.data,
no_alloc: false,
})
},
ContextStorage::Mmap(mmap) => sys::ggml_init_params {
mem_size: mmap.len(),
mem_buffer: std::ptr::null_mut(),
// We are mmapping so ggml does not need to allocate any memory for us
no_alloc: true,
},
ContextStorage::Allocate { mem_size } => sys::ggml_init_params {
mem_size: *mem_size,
// Null here means we want ggml to own this memory.
mem_buffer: std::ptr::null_mut(),
// It doesn't make sense to `no_alloc` when passing in a `mem_size` in this mode.
no_alloc: false,
},
};

let raw = unsafe { sys::ggml_init(init_params) };
Self {
inner: ContextInner::new(raw),
mmap: None,
buffer: Some(buffer),
storage: Some(storage),
can_offload: false,
}
}

/// Creates a new [Context] with the memory mapped file provided
pub fn init_mmap(mmap: Mmap) -> Self {
let raw = unsafe {
sys::ggml_init(sys::ggml_init_params {
mem_size: mmap.len(),
mem_buffer: std::ptr::null_mut(),
no_alloc: true, // We are mmapping so ggml does not need to allocate any memory for us
})
};
/// Creates a new [Context] with the specified buffer.
/// The buffer will be used by GGML.
pub fn new_with_buffer(buffer: Buffer) -> Self {
Self::new(ContextStorage::Buffer(buffer))
}

Self {
inner: ContextInner::new(raw),
mmap: Some(mmap),
buffer: None,
can_offload: false,
}
/// Creates a new [Context] with the specified memory mapped file.
pub fn new_with_mmap(mmap: Mmap) -> Self {
Self::new(ContextStorage::Mmap(mmap))
}

/// Creates a new [Context] with the specified `mem_size` as a working area.
pub fn init(mem_size: usize, alloc: bool) -> Self {
let raw = unsafe {
sys::ggml_init(sys::ggml_init_params {
mem_size,
// Null here means we want ggml to own this memory.
mem_buffer: std::ptr::null_mut(),
no_alloc: !alloc,
})
};
/// Creates a new [Context] with the specified memory size.
/// The memory will be allocated by GGML.
pub fn new_with_allocate(mem_size: usize) -> Self {
Self::new(ContextStorage::Allocate { mem_size })
}

Self {
inner: ContextInner::new(raw),
mmap: None,
buffer: None,
can_offload: false,
}
/// Recreates this context using the same storage.
pub fn recreate(&mut self) {
// This is the only operation that can consume the `self.storage`, so we can unwrap here.
*self = Self::new(self.storage.take().unwrap());
}

/// If offloading is enabled, all tensors created by this context will be offloaded to the GPU
Expand Down Expand Up @@ -182,6 +190,14 @@ impl Context {
let raw = unsafe { sys::ggml_new_f32(self.as_ptr(), x) };
self.new_tensor_raw(raw)
}

/// Returns the mmap used by this [Context], if any.
pub fn mmap(&self) -> Option<&Mmap> {
match &self.storage {
Some(ContextStorage::Mmap(mmap)) => Some(mmap),
_ => None,
}
}
}
// Operations
impl Context {
Expand Down
2 changes: 1 addition & 1 deletion crates/ggml/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub mod util;

pub mod accelerator;

pub use context::Context;
pub use context::{Context, ContextStorage};
pub use tensor::Tensor;

pub use ggml_sys as sys;
Expand Down
6 changes: 3 additions & 3 deletions crates/llm-base/src/inference_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ impl InferenceSession {
ggml::accelerator::set_scratch_size(config.n_batch * 1024 * 1024);
}

let session_ctx = Arc::new(ggml::Context::init(ctx_size, true));
let session_ctx = Arc::new(ggml::Context::new_with_allocate(ctx_size));

// Initialize key + value memory tensors
let n_mem = n_layer * n_ctx;
Expand All @@ -167,7 +167,7 @@ impl InferenceSession {
};

let eval = Buffer::new(buf_size);
let ctx0 = ggml::Context::init_buffer(eval);
let ctx0 = ggml::Context::new_with_buffer(eval);

// Set up Metal support
#[cfg(feature = "metal")]
Expand Down Expand Up @@ -216,7 +216,7 @@ impl InferenceSession {
F: FnOnce(BuildContext) -> (ComputationGraph, GraphOutputs),
{
// Build a graph
self.ctx0 = ggml::Context::init_buffer(self.ctx0.buffer.take().unwrap());
self.ctx0.recreate();
let ctx0 = &mut self.ctx0;
let mut embd = ctx0.new_tensor_1d(ggml::Type::I32, input_tokens.len());
ggml::set_tensor_name(&embd, "embd");
Expand Down
6 changes: 3 additions & 3 deletions crates/llm-base/src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,10 @@ pub fn load<M: KnownModel>(
unsafe {
let mmap = Mmap::map(&file)?;
let file_size = mmap.len() as u64;
(Context::init_mmap(mmap), file_size)
(Context::new_with_mmap(mmap), file_size)
}
} else {
(Context::init(ctx_size, true), file.metadata()?.len())
(Context::new_with_allocate(ctx_size), file.metadata()?.len())
};

let tensors_len = tensors.len();
Expand Down Expand Up @@ -646,7 +646,7 @@ impl TensorLoader<LoadError> for MmapCompatibleLoader<'_> {
&self.context,
&mut self.file,
&self.path,
self.context.mmap.as_ref(),
self.context.mmap(),
);

let mut tensor = main_context.get_tensor(info)?;
Expand Down
2 changes: 1 addition & 1 deletion crates/llm-base/src/lora.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl LoraAdapter {

// Create a temporary context for the patching operations
// TODO: test if GPU can be enabled (make it configurable)
let patch_context = ggml::Context::init(patch_context_size, true);
let patch_context = ggml::Context::new_with_allocate(patch_context_size);
let mut patch_file = FileContext::new(&patch_context, &mut self.file, &self.path, None);

// Load the A and B tensors
Expand Down

0 comments on commit 70d57dc

Please sign in to comment.