fix rustformers#350 - rewrite chat/repl logic for correctness
philpax committed Jul 5, 2023
1 parent 68cefa6 · commit 8a48238
Showing 3 changed files with 97 additions and 67 deletions.
binaries/llm-cli/src/cli_args.rs: 10 changes (5 additions, 5 deletions)
@@ -71,6 +71,11 @@ pub struct Infer {
#[arg(long, default_value_t = false)]
pub hide_prompt: bool,

/// Loads a saved inference session from the given path, previously saved using
/// `--save-session`
#[arg(long, default_value = None)]
pub load_session: Option<PathBuf>,

/// Saves an inference session at the given path. The same session can then be
/// loaded from disk using `--load-session`.
///
@@ -246,11 +251,6 @@ pub struct Generate
#[arg(long, default_value_t = 0.95)]
pub top_p: f32,

/// Loads a saved inference session from the given path, previously saved using
/// `--save-session`
#[arg(long, default_value = None)]
pub load_session: Option<PathBuf>,

/// Specifies the seed to use during sampling. Note that, depending on
/// hardware, the same seed may lead to different results on two separate
/// machines.
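
Taken together, these two hunks move the `--load-session` flag out of the shared Generate argument group and onto the infer subcommand itself; after this commit, infer is the only subcommand that resumes saved sessions (perplexity now always starts fresh, as the next file shows). A minimal sketch of the clap derive pattern used above, with a hypothetical standalone struct and main rather than the crate's actual CLI:

    use std::path::PathBuf;

    use clap::Parser;

    /// Hypothetical argument struct mirroring the pattern above.
    #[derive(Parser)]
    struct Infer {
        /// Loads a saved inference session from the given path, previously
        /// saved using `--save-session`.
        #[arg(long, default_value = None)]
        load_session: Option<PathBuf>,
    }

    fn main() {
        let args = Infer::parse();
        match args.load_session {
            Some(path) => println!("would resume the session from {}", path.display()),
            None => println!("would start a fresh session"),
        }
    }

An `Option<PathBuf>` field marked `#[arg(long)]` parses to `None` when the flag is omitted, which is why the call sites below can simply pass `args.load_session.as_deref()` through.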
binaries/llm-cli/src/main.rs: 144 changes (82 additions, 62 deletions)
@@ -7,7 +7,7 @@ use std::{
use clap::Parser;
use cli_args::Args;
use color_eyre::eyre::{bail, Context, ContextCompat, Result};
use llm::{InferenceError, InferenceFeedback, InferenceResponse};
use llm::{InferenceError, InferenceFeedback, InferenceResponse, InferenceSession};
use rustyline::{
error::ReadlineError,
history::DefaultHistory,
@@ -45,7 +45,7 @@ fn infer(args: &cli_args::Infer) -> Result<()> {
let (mut session, session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
args.persist_session.as_deref(),
args.generate.load_session.as_deref(),
args.load_session.as_deref(),
inference_session_config,
);
let parameters = args.generate.inference_parameters(model.eot_token_id());
@@ -109,12 +109,8 @@ fn perplexity(args: &cli_args::Perplexity) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?;
let inference_session_config = args.generate.inference_session_config();
let model = args.model_load.load(args.generate.use_gpu)?;
let (mut session, _) = snapshot::read_or_create_session(
model.as_ref(),
None,
args.generate.load_session.as_deref(),
inference_session_config,
);
let (mut session, _) =
snapshot::read_or_create_session(model.as_ref(), None, None, inference_session_config);
let parameters = args.generate.inference_parameters(model.eot_token_id());

session.perplexity(
@@ -227,15 +223,17 @@ fn repl(args: &cli_args::Repl) -> Result<()> {
interactive(
&args.generate,
&args.model_load,
args.prompt_file.contents()?.as_deref(),
false,
None,
args.prompt_file.contents()?.as_deref(),
)
}

fn chat(args: &cli_args::Chat) -> Result<()> {
interactive(
&args.generate,
&args.model_load,
true,
Some(std::fs::read_to_string(&args.prelude_prompt_file)?.as_str()),
Some(&args.message_prompt()?),
)
@@ -244,94 +242,116 @@ fn chat(args: &cli_args::Chat) -> Result<()> {
fn interactive(
generate: &cli_args::Generate,
model_load: &cli_args::ModelLoad,
initial_prompt_template: Option<&str>,
chat_mode: bool,
mut initial_prompt_template: Option<&str>,
message_prompt_template: Option<&str>,
) -> Result<()> {
let inference_session_config = generate.inference_session_config();
let model = model_load.load(generate.use_gpu)?;

let (mut session, mut session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
None,
generate.load_session.as_deref(),
inference_session_config,
);
let parameters = generate.inference_parameters(model.eot_token_id());
let recreate_session =
|| snapshot::read_or_create_session(model.as_ref(), None, None, inference_session_config).0;
let mut session = recreate_session();

let parameters = generate.inference_parameters(model.eot_token_id());
let mut rng = generate.rng();

fn session_ends_with_newline(session: &InferenceSession) -> bool {
session
.decoded_tokens()
.last()
.map(|t| *t == b'\n')
.unwrap_or(false)
}
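
The new `session_ends_with_newline` helper looks at the last decoded byte of the session's output, treating an empty session as not ending in a newline; the closure below uses it to decide whether to prepend a newline before feeding a prompt, and the main loop uses it to decide whether to print one after a reply. The same check as a self-contained sketch over a plain byte slice (`b'\n'` is a single byte in UTF-8, so inspecting one trailing byte is sound here):

    fn ends_with_newline(decoded: &[u8]) -> bool {
        // `last()` is `None` for an empty buffer, which counts as
        // "no trailing newline".
        decoded.last().map(|b| *b == b'\n').unwrap_or(false)
    }

    fn main() {
        assert!(ends_with_newline(b"reply\n"));
        assert!(!ends_with_newline(b"reply"));
        assert!(!ends_with_newline(b""));
    }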

let mut infer = |session: &mut InferenceSession, mut prompt: String| {
// Add a newline to the beginning of the prompt if the last character in the session is not a newline
if !session_ends_with_newline(session) {
prompt.insert(0, '\n');
}

let sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None);
if let Err(InferenceError::ContextFull) = session.feed_prompt(
model.as_ref(),
&parameters,
&prompt,
// OutputRequest
&mut Default::default(),
|_| Ok::<_, Infallible>(InferenceFeedback::Continue),
) {
log::error!("Prompt exceeds context window length.")
};
sp.clear();

session.infer::<Infallible>(
model.as_ref(),
&mut rng,
&llm::InferenceRequest {
prompt: "".into(),
parameters: &parameters,
play_back_previous_tokens: false,
maximum_token_count: generate.num_predict,
},
&mut Default::default(),
|r| match r {
InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
print!("{t}");
std::io::stdout().flush().unwrap();

Ok(InferenceFeedback::Continue)
}
_ => Ok(InferenceFeedback::Continue),
},
)
};
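
The `infer` closure splits each turn into two phases: `feed_prompt` pushes the user's text into the session behind a spinner, with a callback that returns `InferenceFeedback::Continue` so nothing is echoed, and then `infer` runs with an empty prompt (the text is already in the context) and `play_back_previous_tokens: false`, so only newly generated tokens are printed. A toy sketch of that two-phase shape, with a hypothetical `ToySession` standing in for the real inference session:

    struct ToySession {
        context: String,
    }

    impl ToySession {
        // Phase 1: the prompt enters the context, producing no output.
        fn feed_prompt(&mut self, prompt: &str) {
            self.context.push_str(prompt);
        }

        // Phase 2: generate from the accumulated context; no new prompt needed.
        fn infer(&self) -> String {
            format!("<reply to: {}>", self.context.trim())
        }
    }

    fn main() {
        let mut session = ToySession { context: String::new() };
        session.feed_prompt("Hello!\n");
        println!("{}", session.infer());
    }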

let mut rl = rustyline::Editor::<LineContinuationValidator, DefaultHistory>::new()?;
rl.set_helper(Some(LineContinuationValidator));
rl.bind_sequence(force_newline_event_seq(), Cmd::Newline);

loop {
let readline = rl.readline(">> ");
match readline {
Ok(raw_line) => {
let line = raw_line.replace("\\\n", "\n");

let prompt = message_prompt_template
.or(initial_prompt_template)
// Use the initial prompt template for the first inference,
// and then switch to the message prompt template afterwards
let mut prompt = initial_prompt_template
.take()
.or(message_prompt_template)
.map(|pf| process_prompt(pf, &line))
.unwrap_or(line);
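
`Option::take()` is what switches templates between the first turn and the rest: it moves the initial (prelude) template out on the first iteration, leaving `None` so later turns fall through to `message_prompt_template`. The replaced lines preferred the message template on every turn, so chat mode's prelude could never be applied. A standalone sketch of the new selection order (the `{{PROMPT}}` placeholder is assumed here for illustration; the real substitution lives in `process_prompt`):

    fn main() {
        let mut initial: Option<&str> = Some("SYSTEM: {{PROMPT}}");
        let message: Option<&str> = Some("USER: {{PROMPT}}");

        for line in ["hello", "how are you?", "goodbye"] {
            // First pass: `take()` empties `initial`, so it is used exactly once.
            let template = initial.take().or(message).unwrap();
            println!("{}", template.replace("{{PROMPT}}", line));
        }
        // SYSTEM template for the first line, USER template afterwards.
    }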

let sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None);
if let Err(InferenceError::ContextFull) = session.feed_prompt(
model.as_ref(),
&parameters,
&prompt,
// OutputRequest
&mut Default::default(),
|_| Ok::<_, Infallible>(InferenceFeedback::Continue),
) {
log::error!("Prompt exceeds context window length.")
};
sp.clear();

let res = session.infer::<Infallible>(
model.as_ref(),
&mut rng,
&llm::InferenceRequest {
prompt: "".into(),
parameters: &parameters,
play_back_previous_tokens: session_loaded,
maximum_token_count: generate.num_predict,
},
// EvaluateOuputRequest
&mut Default::default(),
|r| match r {
InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
print!("{t}");
std::io::stdout().flush().unwrap();

Ok(InferenceFeedback::Continue)
}
_ => Ok(InferenceFeedback::Continue),
},
);
println!();
// Add a newline to the end of the prompt if it doesn't end with one in chat mode
if chat_mode && !prompt.ends_with('\n') {
prompt.push('\n');
}

if let Err(err) = infer(&mut session, prompt) {
log::error!("{err}");
break;
}

if let Err(InferenceError::ContextFull) = res {
log::error!("Reply exceeds context window length");
if !session_ends_with_newline(&session) {
println!();
}

// Reload session in REPL mode
if message_prompt_template.is_none() {
(session, session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
None,
generate.load_session.as_deref(),
inference_session_config,
);
session = recreate_session();
}
}
Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => {
break;
}
Err(err) => {
log::error!("{err}");
break;
}
}
}

Ok(())
}
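
One more behavioural detail from the loop above: after a successful turn, the session is rebuilt whenever `message_prompt_template.is_none()`. That can only hold on the REPL path, since chat mode always supplies a message template, so chat keeps its session (conversation history accumulates) while a plain REPL (one without a prompt template file) starts every line from a fresh context. A toy sketch of that policy, with a `Vec<String>` standing in for the inference session:

    fn main() {
        // Stands in for `recreate_session` from the diff above.
        let recreate_history = Vec::<String>::new;
        let repl_mode = true; // i.e. no message prompt template was supplied

        let mut history = recreate_history();
        for line in ["first input", "second input"] {
            history.push(line.to_string());
            println!("context holds {} turn(s)", history.len());
            if repl_mode {
                // Drop the accumulated context so each line stands alone.
                history = recreate_history();
            }
        }
    }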

crates/llm-base/src/inference_session.rs: 10 changes (10 additions, 0 deletions)
@@ -590,6 +590,16 @@ impl InferenceSession {

Ok(session)
}

/// All tokens generated by this inference session
pub fn tokens(&self) -> &[TokenId] {
self.tokens.as_ref()
}

/// All decoded tokens generated by this inference session
pub fn decoded_tokens(&self) -> &[u8] {
self.decoded_tokens.as_ref()
}
}
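
These two read-only accessors are what let `main.rs` inspect session state from outside the crate (`session_ends_with_newline` above is built on `decoded_tokens()`). A hedged usage sketch, assuming the `llm` crate's re-exported `InferenceSession` from this diff; the decoded bytes are raw model output and need not be valid UTF-8, hence the lossy conversion:

    /// Number of tokens the session has processed so far.
    fn token_count(session: &llm::InferenceSession) -> usize {
        session.tokens().len()
    }

    /// The session's textual output so far, decoded lossily as UTF-8.
    fn session_text(session: &llm::InferenceSession) -> String {
        String::from_utf8_lossy(session.decoded_tokens()).into_owned()
    }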

fn get_newly_decoded_portion_huggingface(
