Skip to content

Commit

Permalink
Fix clippy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
setzer22 committed Mar 17, 2023
1 parent c6b4a09 commit 34a5bea
Showing 1 changed file with 21 additions and 25 deletions.
46 changes: 21 additions & 25 deletions llama-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,13 +535,11 @@ impl Model {
path: part_path,
});
}
} else {
if tensor.nelements() / n_parts != nelements {
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
});
}
} else if tensor.nelements() / n_parts != nelements {
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
});
}

if n_dims == 1 {
Expand All @@ -551,22 +549,18 @@ impl Model {
path: part_path,
});
}
} else {
if split_type == 0 {
if tensor.get_ne()[0] / n_parts != ne[0] || tensor.get_ne()[1] != ne[1] {
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
});
}
} else {
if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] / n_parts != ne[1] {
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
});
}
} else if split_type == 0 {
if tensor.get_ne()[0] / n_parts != ne[0] || tensor.get_ne()[1] != ne[1] {
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
});
}
} else if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] / n_parts != ne[1] {
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
});
}

let bpe = match ftype {
Expand Down Expand Up @@ -697,7 +691,7 @@ impl Model {

// determine the required inference memory per token:
let mut mem_per_token = 0;
let _ = self.evaluate(
self.evaluate(
params.n_threads,
0,
&[0, 1, 2, 3],
Expand All @@ -718,7 +712,7 @@ impl Model {
let mut embd = Vec::new();
while remaining_tokens > 0 {
// predict
if embd.len() > 0 {
if !embd.is_empty() {
self.evaluate(
params.n_threads,
n_past,
Expand Down Expand Up @@ -793,6 +787,7 @@ impl Model {
}
}

#[allow(clippy::too_many_arguments)]
pub fn sample_top_p_top_k(
&self,
vocab: &Vocabulary,
Expand Down Expand Up @@ -844,7 +839,7 @@ impl Model {
let mut probs: Vec<f64> = logits_id
.iter()
.copied()
.map(|(k, v)| (k - maxl).exp())
.map(|(k, _)| (k - maxl).exp())
.collect();
let sum: f64 = probs.iter().copied().sum();

Expand Down Expand Up @@ -877,6 +872,7 @@ impl Model {
logits_id[idx].1
}

#[allow(non_snake_case)]
pub fn evaluate(
&self,
n_threads: i32,
Expand Down

0 comments on commit 34a5bea

Please sign in to comment.