Skip to content

Commit

Permalink
fix rustformers#298: don't send new bytes if invalid decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
philpax committed Jun 29, 2023
1 parent 60d6168 commit 7e2f2bf
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions crates/llm-base/src/inference_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,13 +304,10 @@ impl InferenceSession {
let mut token = match model.tokenizer() {
crate::Tokenizer::Embedded(_) => model.tokenizer().token(tk as usize).to_vec(),
crate::Tokenizer::HuggingFace(_) => {
let mut previous_tokens = self.tokens.clone();
previous_tokens.push(tk);
let mut tokens = self.tokens.clone();
tokens.push(tk);

let all_tokens = model.tokenizer().decode(previous_tokens, true);
let splitted = all_tokens.split_at(self.decoded_tokens.len());

splitted.1.to_vec()
get_newly_decoded_portion_huggingface(model, tokens, &self.decoded_tokens)
}
};

Expand Down Expand Up @@ -363,12 +360,11 @@ impl InferenceSession {
crate::Tokenizer::Embedded(_) => {
model.tokenizer().token(next_token as usize).to_vec()
}
crate::Tokenizer::HuggingFace(_) => {
let all_tokens = model.tokenizer().decode(self.tokens.clone(), true);
let splitted = all_tokens.split_at(self.decoded_tokens.len());

splitted.1.to_vec()
}
crate::Tokenizer::HuggingFace(_) => get_newly_decoded_portion_huggingface(
model,
self.tokens.clone(),
&self.decoded_tokens,
),
};

self.decoded_tokens.append(&mut res.clone());
Expand Down Expand Up @@ -595,6 +591,22 @@ impl InferenceSession {
}
}

fn get_newly_decoded_portion_huggingface(
model: &dyn Model,
tokens: Vec<u32>,
decoded_tokens: &[u8],
) -> Vec<u8> {
let all_tokens = model.tokenizer().decode(tokens, true);
// The bytes here come from a lossily-decoded String, so we need to convert it back to a String
// to check if it ends with a replacement character.
let all_tokens = unsafe { String::from_utf8_unchecked(all_tokens) };
if all_tokens.ends_with('�') {
// Return an empty vector: no valid text was generated from this token.
return vec![];
}
all_tokens.as_bytes()[decoded_tokens.len()..].to_vec()
}

#[derive(Error, Debug)]
/// Errors encountered during the inference process.
pub enum InferenceError {
Expand Down

0 comments on commit 7e2f2bf

Please sign in to comment.