Skip to content

Commit

Permalink
Add delete tokens test and impl.
Browse files Browse the repository at this point in the history
Note that llama.json fails the tests, so it's likely it doesn't support it.
I may investigate further, though.
  • Loading branch information
steventrouble committed Jul 7, 2023
1 parent b2238c2 commit 2e35b46
Show file tree
Hide file tree
Showing 17 changed files with 221 additions and 19 deletions.
3 changes: 3 additions & 0 deletions binaries/llm-test/configs/bloom.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
"input": "Rustformers is",
"output": 15
}
},
{
"Delete": {}
}
]
}
3 changes: 3 additions & 0 deletions binaries/llm-test/configs/gptj.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
"input": "Rustformers is",
"output": 257
}
},
{
"Delete": {}
}
]
}
3 changes: 3 additions & 0 deletions binaries/llm-test/configs/gptneox.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
"input": "Rustformers is",
"output": 247
}
},
{
"Delete": {}
}
]
}
3 changes: 3 additions & 0 deletions binaries/llm-test/configs/llama.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
"input": "Rustformers is",
"output": 260
}
},
{
"Delete": {}
}
]
}
3 changes: 3 additions & 0 deletions binaries/llm-test/configs/mpt.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
"input": "Rustformers is",
"output": 247
}
},
{
"Delete": {}
}
]
}
100 changes: 100 additions & 0 deletions binaries/llm-test/src/delete.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
//! Tests the model's token manipulation APIs:
//!
//! * [llm::InferenceSession::feed_prompt()]
//!
//! See [crate::TestCase::Tokens].
use std::convert::Infallible;

use llm::{InferenceFeedback, InferenceSession, Model, OutputRequest};
use serde::Serialize;

use crate::{TestCaseReport, TestCaseReportMeta};

/// Error tolerance for the float comparisons.
const TOLERANCE: f32 = 1e-7;

/// Tests that models can delete tokens without changing the model's behavior.
pub(crate) fn can_delete(model: &impl Model) -> TestCaseReport {
let report = DeleteReport::default();
let mut session = model.start_session(Default::default());
let mut output = OutputRequest {
all_logits: Some(vec![]),
..Default::default()
};

// Feed some tokens
if let Err(err) = feed_prompt("The llama lived on the", &mut session, model, &mut output) {
return report.failure(&err.to_string());
}

// Add token and get the logits
if let Err(err) = feed_prompt(" ", &mut session, model, &mut output) {
return report.failure(&err.to_string());
}
let Some(original_logits) = output.all_logits.clone() else {
return report.failure("Model did not return logits.");
};

// Delete, then re-add. Verify logits are the same.
if let Err(err) = session.delete_tokens(model, 1) {
return report.failure(&err.to_string());
}
if let Err(err) = feed_prompt(" ", &mut session, model, &mut output) {
return report.failure(&err.to_string());
}
let Some(redone_logits) = output.all_logits.clone() else {
return report.failure("Second run of model did not return logits.");
};

// Compare the logits
for (idx, (&original, redone)) in original_logits.iter().zip(redone_logits).enumerate() {
if original > redone + TOLERANCE || original < redone - TOLERANCE {
return report.failure(&format!(
"Expected logits to be the same after delete, but differed at {idx}, \
expected {original}, but was {redone}."
));
}
}

log::info!("`can_delete` test passed (no expected output)!");
report.success()
}

fn feed_prompt(
prompt: &str,
session: &mut InferenceSession,
model: &impl Model,
output: &mut OutputRequest,
) -> Result<(), llm::InferenceError> {
session.feed_prompt(model, &Default::default(), prompt, output, |x| {
always_continue(x)
})
}

fn always_continue(_: &[u8]) -> Result<InferenceFeedback, Infallible> {
Ok(InferenceFeedback::Continue)
}

#[derive(Serialize, Default)]
pub struct DeleteReport {
output: usize,
}

impl DeleteReport {
fn failure(self, msg: &str) -> TestCaseReport {
TestCaseReport {
meta: TestCaseReportMeta::Error {
error: msg.to_owned(),
},
report: crate::TestCaseReportInner::Delete(self),
}
}

fn success(self) -> TestCaseReport {
TestCaseReport {
meta: TestCaseReportMeta::Success,
report: crate::TestCaseReportInner::Delete(self),
}
}
}
6 changes: 6 additions & 0 deletions binaries/llm-test/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Test runner for all LLMs.
mod common;
mod delete;
mod inference;
mod tokens;

Expand Down Expand Up @@ -128,6 +129,7 @@ enum TestCase {
input: String,
output: usize,
},
Delete {},
}

#[derive(Serialize)]
Expand Down Expand Up @@ -158,6 +160,7 @@ pub enum TestCaseReportInner {
inference_stats: Option<InferenceStats>,
},
Tokens(tokens::TokensReport),
Delete(delete::DeleteReport),
}

async fn test_model(
Expand Down Expand Up @@ -278,6 +281,9 @@ async fn test_model(
TestCase::Tokens { input, output } => {
test_case_reports.push(tokens::can_feed(&model, input, *output));
}
TestCase::Delete {} => {
test_case_reports.push(delete::can_delete(&model));
}
}
}
let first_error: Option<String> =
Expand Down
23 changes: 14 additions & 9 deletions binaries/llm-test/src/tokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,21 @@
use std::convert::Infallible;

use llm::{InferenceFeedback, Model, OutputRequest};
use llm::{InferenceFeedback, InferenceSession, Model, OutputRequest};
use serde::Serialize;

use crate::{TestCaseReport, TestCaseReportMeta};

/// Tests that the model performs as expected when feeding tokens
pub(crate) fn can_feed(model: &impl Model, input: &str, expected_output: usize) -> TestCaseReport {
let mut report = TokensReport::default();

let mut session = model.start_session(Default::default());
let mut output = OutputRequest {
all_logits: Some(vec![]),
..Default::default()
};

let feed_prompt = &mut |prompt: &str| {
session.feed_prompt(model, &Default::default(), prompt, &mut output, |x| {
always_continue(x)
})
};

if let Err(err) = feed_prompt(input) {
if let Err(err) = feed_prompt(input, &mut session, model, &mut output) {
return report.failure(&err.to_string());
};

Expand Down Expand Up @@ -62,9 +55,21 @@ pub(crate) fn can_feed(model: &impl Model, input: &str, expected_output: usize)
));
}

log::info!("`can_feed` test passed (no expected output)!");
report.success()
}

fn feed_prompt(
prompt: &str,
session: &mut InferenceSession,
model: &impl Model,
output: &mut OutputRequest,
) -> Result<(), llm::InferenceError> {
session.feed_prompt(model, &Default::default(), prompt, output, |x| {
always_continue(x)
})
}

fn always_continue(_: &[u8]) -> Result<InferenceFeedback, Infallible> {
Ok(InferenceFeedback::Continue)
}
Expand Down
43 changes: 43 additions & 0 deletions crates/llm-base/src/inference_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,37 @@ impl InferenceSession {
Ok(())
}

/// Removes `num` tokens from the end of the buffer. Roughly the inverse of `feed_prompt`.
pub fn delete_tokens(
&mut self,
model: &dyn Model,
num: usize,
) -> Result<Vec<TokenId>, DeleteError> {
if !model.supports_delete() {
return Err(DeleteError::UnsupportedArchitecture);
}

if num >= self.n_past {
return Err(DeleteError::NotEnoughTokens);
}

// Remove the tokens from self.tokens.
let token_start = self.n_past - num;
let deleted_tokens: Vec<_> = self.tokens.drain(token_start..).collect();

// Remove the corresponding chars from decoded
let mut decoded_start = self.decoded_tokens.len();
for id in &deleted_tokens {
decoded_start -= model.tokenizer().token(*id as usize).len();
}
self.decoded_tokens.truncate(decoded_start);

// Decrement the n_past tokens counter.
self.n_past -= num;

Ok(deleted_tokens)
}

/// Infer the next token for this session.
pub fn infer_next_token(
&mut self,
Expand Down Expand Up @@ -637,6 +668,18 @@ pub enum InferenceError {
UserCallback(Box<dyn std::error::Error>),
}

#[derive(Error, Debug)]
/// Errors encountered during the snapshot process.
pub enum DeleteError {
/// Tried deleting more tokens than were available
#[error("tried deleting more tokens than were available")]
NotEnoughTokens,

/// Model architecture does not support delete
#[error("model architecture does not support deletes")]
UnsupportedArchitecture,
}

#[derive(Error, Debug)]
/// Errors encountered during the snapshot process.
pub enum SnapshotError {
Expand Down
6 changes: 3 additions & 3 deletions crates/llm-base/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ pub use ggml;
pub use ggml::Type as ElementType;

pub use inference_session::{
feed_prompt_callback, GraphOutputs, InferenceError, InferenceFeedback, InferenceRequest,
InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot,
InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, SnapshotError,
feed_prompt_callback, DeleteError, GraphOutputs, InferenceError, InferenceFeedback,
InferenceRequest, InferenceResponse, InferenceSession, InferenceSessionConfig,
InferenceSnapshot, InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, SnapshotError,
};
pub use loader::{
load, load_progress_callback_stdout, ContainerType, FileType, FileTypeFormat, FormatMagic,
Expand Down
13 changes: 13 additions & 0 deletions crates/llm-base/src/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ pub trait KnownModel: Send + Sync {

/// Get the list of regexes to use to determine if a tensor in this model should not be quantized.
fn skip_quantize_tensors() -> Vec<Regex>;

/// Returns whether the model supports deleting tokens.
fn supports_delete(&self) -> bool {
// Assume we can't delete unless otherwise specified
false
}
}

/// A type-erased model to allow for interacting with a model without knowing
Expand Down Expand Up @@ -118,6 +124,9 @@ pub trait Model: Send + Sync {

/// Get the end of text/end of string token ID. This value is defined by model implementers.
fn eot_token_id(&self) -> TokenId;

/// Returns whether the model supports deleting tokens.
fn supports_delete(&self) -> bool;
}
impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M {
fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
Expand Down Expand Up @@ -149,6 +158,10 @@ impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M {
fn eot_token_id(&self) -> TokenId {
KnownModel::eot_token_id(self)
}

fn supports_delete(&self) -> bool {
KnownModel::supports_delete(self)
}
}

/// Implemented by model hyperparameters for interacting with hyperparameters
Expand Down
14 changes: 7 additions & 7 deletions crates/llm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ use std::{
// This is the "user-facing" API, and GGML may not always be our backend.
pub use llm_base::{
feed_prompt_callback, ggml::format as ggml_format, load, load_progress_callback_stdout,
quantize, samplers, ElementType, FileType, FileTypeFormat, FormatMagic, Hyperparameters,
InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest, InferenceResponse,
InferenceSession, InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef,
InferenceStats, InvalidTokenBias, KnownModel, LoadError, LoadProgress, Loader, Model,
ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress,
Sampler, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, Tokenizer,
TokenizerSource,
quantize, samplers, DeleteError, ElementType, FileType, FileTypeFormat, FormatMagic,
Hyperparameters, InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest,
InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot,
InferenceSnapshotRef, InferenceStats, InvalidTokenBias, KnownModel, LoadError, LoadProgress,
Loader, Model, ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError,
QuantizeProgress, Sampler, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer,
TokenizationError, Tokenizer, TokenizerSource,
};

use serde::Serialize;
Expand Down
4 changes: 4 additions & 0 deletions crates/models/bloom/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,10 @@ impl KnownModel for Bloom {
fn skip_quantize_tensors() -> Vec<Regex> {
vec![]
}

fn supports_delete(&self) -> bool {
true
}
}

/// BLOOM [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
Expand Down
4 changes: 4 additions & 0 deletions crates/models/gptj/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,10 @@ impl KnownModel for GptJ {
fn skip_quantize_tensors() -> Vec<Regex> {
vec![]
}

fn supports_delete(&self) -> bool {
true
}
}

/// GPT-J [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
Expand Down
4 changes: 4 additions & 0 deletions crates/models/gptneox/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,10 @@ impl KnownModel for GptNeoX {
fn skip_quantize_tensors() -> Vec<Regex> {
vec![]
}

fn supports_delete(&self) -> bool {
true
}
}

/// GPT-NeoX [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
Expand Down
Loading

0 comments on commit 2e35b46

Please sign in to comment.