Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Permalink
Check tensor dimensions when loading
Browse files Browse the repository at this point in the history
  • Loading branch information
danforbes authored Apr 27, 2023
1 parent c608b4b commit e19418c
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion llm-base/src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,19 @@ pub fn load<M: Model>(
tensor_name: name.to_owned(),
})?;

let dims = ne.len();
if dims != info.n_dims {
return Err(LoadError::InvariantBroken {
path: self.path.clone(),
invariant: format!(
"the tensor {name} should have {} dimensions, not {dims}",
info.n_dims
),
});
}

let ctx = &self.context;
let mut tensor = match ne.len() {
let mut tensor = match dims {
1 => ctx.new_tensor_1d(info.element_type, ne[0]),
2 => ctx.new_tensor_2d(info.element_type, ne[0], ne[1]),
3 => ctx.new_tensor_3d(info.element_type, ne[0], ne[1], ne[2]),
Expand Down

0 comments on commit e19418c

Please sign in to comment.