Skip to content

Commit

Permalink
Fix OpenCL inference
Browse files Browse the repository at this point in the history
  • Loading branch information
LLukas22 committed Jul 7, 2023
1 parent d933f1e commit 7ec3683
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 17 deletions.
5 changes: 0 additions & 5 deletions crates/ggml/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,9 +559,4 @@ pub fn accelerator_initialize(device: i32) {
let split = 1.0f32;
sys::cuda::ggml_cuda_set_tensor_split(&split as *const f32);
}

#[cfg(feature = "clblast")]
unsafe {
sys::opencl::ggml_cl_init();
}
}
10 changes: 0 additions & 10 deletions crates/llm-base/src/inference_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,6 @@ impl<'session> BuildContext<'session> {
pub fn get_scratch(&self, idx: usize) -> Option<&Buffer> {
Some(&self.scratch[idx])
}

pub fn enable_offloading(&self) {
let mut ctx0 = self.ctx0.borrow_mut();
ctx0.enable_offloading();
}

pub fn disable_offloading(&self) {
let mut ctx0 = self.ctx0.borrow_mut();
ctx0.disable_offloading();
}
}

unsafe impl Send for InferenceSession {}
Expand Down
12 changes: 10 additions & 2 deletions crates/models/llama/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,16 @@ impl KnownModel for Llama {

// model-global weights
let wte = tl.load("tok_embeddings.weight")?;
let norm = tl.offload("norm.weight", Backend::Gpu)?;

let output = tl.offload("output.weight", Backend::Gpu)?;
let backend = if params.should_offload(0) {
Backend::Gpu
} else {
Backend::Cpu
};

let norm = tl.offload("norm.weight", backend)?;

let output = tl.offload("output.weight", backend)?;

let mut layers = Vec::new();

Expand Down Expand Up @@ -131,6 +138,7 @@ impl KnownModel for Llama {
let outputs = session.compute(self.context.clone(), input_tokens, |builder| {
let mut ctx0 = builder.ctx0.borrow_mut();
let embd = builder.embd;

let mut input_layer = ctx0.op_get_rows(&self.wte, embd);

// for big prompts, if BLAS is enabled, it is better to use only one thread
Expand Down

0 comments on commit 7ec3683

Please sign in to comment.