Skip to content

Commit

Permalink
fix(llama): load parts based on dir, not n_embd
Browse files Browse the repository at this point in the history
  • Loading branch information
philpax committed Mar 18, 2023
1 parent 3ee9b0b commit 47e361d
Showing 1 changed file with 50 additions and 38 deletions.
88 changes: 50 additions & 38 deletions llama-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,6 @@ impl Display for OutputToken<'_> {
}
}

/// Looks up how many parts a model is split into for the given `size`.
///
/// The recognised sizes and their part counts are 4096 -> 1, 5120 -> 2,
/// 6656 -> 4 and 8192 -> 8.
///
/// # Panics
///
/// Panics if `size` is not one of the recognised values listed above.
fn llama_n_parts(size: i32) -> i32 {
    // (size, number of parts) pairs for every recognised size.
    const PARTS_BY_SIZE: [(i32, i32); 4] = [(4096, 1), (5120, 2), (6656, 4), (8192, 8)];

    PARTS_BY_SIZE
        .iter()
        .find(|&&(known_size, _)| known_size == size)
        .map(|&(_, part_count)| part_count)
        .unwrap_or_else(|| unreachable!("Invalid size for N_PARTS"))
}

/// Each variant represents a step within the process of loading the model.
/// These can be used to report progress to the user.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
Expand Down Expand Up @@ -215,6 +205,8 @@ pub enum LoadError {
source: std::io::Error,
path: PathBuf,
},
#[error("no parent path for {path:?}")]
NoParentPath { path: PathBuf },
#[error("unable to read exactly {bytes} bytes")]
ReadExactFailed {
source: std::io::Error,
Expand Down Expand Up @@ -281,13 +273,15 @@ impl Model {
use std::fs::File;
use std::io::BufReader;

let path = path.as_ref();
let main_path = path.as_ref();

let mut reader =
BufReader::new(File::open(path).map_err(|e| LoadError::OpenFileFailed {
source: e,
path: path.to_owned(),
})?);
BufReader::new(
File::open(main_path).map_err(|e| LoadError::OpenFileFailed {
source: e,
path: main_path.to_owned(),
})?,
);

/// Helper function. Reads an int from the buffer and returns it.
fn read_i32(reader: &mut impl BufRead) -> Result<i32, LoadError> {
Expand Down Expand Up @@ -319,7 +313,7 @@ impl Model {
let magic = read_i32(&mut reader)?;
if magic != 0x67676d6c {
return Err(LoadError::InvalidMagic {
path: path.to_owned(),
path: main_path.to_owned(),
});
}
}
Expand All @@ -343,7 +337,6 @@ impl Model {

let n_ff =
((2 * (4 * hparams.n_embd) / 3 + hparams.n_mult - 1) / hparams.n_mult) * hparams.n_mult;
let n_parts = llama_n_parts(hparams.n_embd);

load_progress_callback(LoadProgress::HyperparametersLoaded(&hparams));

Expand Down Expand Up @@ -490,23 +483,38 @@ impl Model {
let file_offset = reader.stream_position()?;
drop(reader);

for i in 0..n_parts {
let part_id = i;
// Build the sorted list of files to load: every entry in the main file's
// directory whose file name starts with the main file's name.
let paths = {
// `None` if the path has no final component or the name is not valid UTF-8;
// the `zip` in the filter below then yields `None` and nothing matches.
let main_filename = main_path.file_name().and_then(|p| p.to_str());

let mut paths: Vec<PathBuf> =
// A path with no parent is reported as `LoadError::NoParentPath`.
// NOTE(review): for a bare relative file name `parent()` returns an empty
// path rather than `None` — confirm `read_dir` behaves as wanted there.
std::fs::read_dir(main_path.parent().ok_or_else(|| LoadError::NoParentPath {
path: main_path.to_owned(),
})?)?
// Directory entries that could not be read are silently skipped.
.filter_map(Result::ok)
.map(|de| de.path())
.filter(|p| {
p.file_name()
.and_then(|p| p.to_str())
.zip(main_filename)
.map(|(part_filename, main_filename)| {
// NOTE(review): a prefix test also accepts unrelated siblings that
// merely share the prefix (e.g. a `.bak` copy) — confirm this is intended.
part_filename.starts_with(main_filename)
})
// Entries without a UTF-8 file name are excluded.
.unwrap_or(false)
})
.collect();
// NOTE(review): this orders by path comparison, not by numeric suffix, so a
// `.10` suffix would sort before `.2` — verify against the expected part count.
paths.sort();
paths
};

let part_path = if i > 0 {
let mut path = path.to_owned();
let mut filename = path.components().last().unwrap().as_os_str().to_owned();
filename.push(&format!(".{i}"));
path.pop();
path.join(filename)
} else {
path.to_path_buf()
};
let n_parts = paths.len();

for (i, part_path) in paths.into_iter().enumerate() {
let part_id = i;

load_progress_callback(LoadProgress::PartLoading {
file: &part_path,
current_part: (i + 1).try_into()?,
total_parts: n_parts.try_into()?,
current_part: i + 1,
total_parts: n_parts,
});

let mut part_reader = BufReader::new(File::open(&part_path)?);
Expand Down Expand Up @@ -585,7 +593,7 @@ impl Model {
path: part_path,
});
}
} else if tensor.nelements() / n_parts != nelements {
} else if tensor.nelements() / i32::try_from(n_parts)? != nelements {
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
Expand All @@ -600,13 +608,17 @@ impl Model {
});
}
} else if split_type == 0 {
if tensor.get_ne()[0] / n_parts != ne[0] || tensor.get_ne()[1] != ne[1] {
if tensor.get_ne()[0] / i32::try_from(n_parts)? != ne[0]
|| tensor.get_ne()[1] != ne[1]
{
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
});
}
} else if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] / n_parts != ne[1] {
} else if tensor.get_ne()[0] != ne[0]
|| tensor.get_ne()[1] / i32::try_from(n_parts)? != ne[1]
{
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
Expand Down Expand Up @@ -657,7 +669,7 @@ impl Model {
total_size += tensor.nbytes();
} else {
if (nelements as usize * bpe) / ggml::blck_size(tensor.get_type()) as usize
!= tensor.nbytes() / n_parts as usize
!= tensor.nbytes() / n_parts
{
return Err(LoadError::TensorWrongSize {
tensor_name,
Expand All @@ -676,15 +688,15 @@ impl Model {
for i1 in 0..ne[1] {
let offset_row = i1 as usize * row_size;
let offset = offset_row
+ ((part_id * np0) as usize
+ ((part_id * np0 as usize)
/ ggml::blck_size(tensor.get_type()) as usize)
* ggml::type_size(tensor.get_type());
// SAFETY: yolo, same as original code
unsafe {
let ptr = tensor.data().add(offset);
let slice = std::slice::from_raw_parts_mut(
ptr as *mut u8,
row_size / n_parts as usize,
row_size / n_parts,
);
part_reader.read_exact(slice)?;
}
Expand All @@ -696,7 +708,7 @@ impl Model {
* ggml::type_size(tensor.get_type());

for i1 in 0..ne[1] {
let offset_row = (i1 + part_id * np1) as usize * row_size;
let offset_row = (i1 as usize + part_id * np1 as usize) * row_size;
// SAFETY: yolo, same as original code
unsafe {
let ptr = tensor.data().add(offset_row);
Expand All @@ -707,7 +719,7 @@ impl Model {
}
}

total_size += tensor.nbytes() / n_parts as usize
total_size += tensor.nbytes() / n_parts;
}

n_tensors += 1;
Expand Down

0 comments on commit 47e361d

Please sign in to comment.