fix rustformers#216 - split prompts for chat mode in CLI
philpax committed Jul 2, 2023
1 parent 6e2362b commit 23d2bf8
Showing 2 changed files with 121 additions and 55 deletions.
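In short, the chat subcommand no longer reuses the `Repl` arguments: it now takes a prelude/initial prompt file (`-f`/`--prelude-prompt-file`) plus a per-message prompt template supplied either inline (`-p`/`--message-prompt`) or from a file (`-q`/`--message-prompt-file`), and both templates must contain a `{{PROMPT}}` placeholder. As a rough illustration of the shape such templates might take (hypothetical contents, not taken from this repository), written here as Rust constants:

// Hypothetical template contents, for illustration only.
// The prelude would normally live in the file passed via --prelude-prompt-file;
// per the doc comments below, its {{PROMPT}} placeholder is replaced with the
// first user prompt.
const PRELUDE: &str = "A chat between a human and an assistant.\n\nHuman: {{PROMPT}}\nAssistant:";

// The per-message template is supplied via --message-prompt or
// --message-prompt-file; its {{PROMPT}} placeholder is replaced with the
// user's message.
const MESSAGE_PROMPT: &str = "Human: {{PROMPT}}\nAssistant:";

fn main() {
    // Print the raw templates; the actual substitution happens per message.
    println!("{PRELUDE}");
    println!("{MESSAGE_PROMPT}");
}
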
binaries/llm-cli/src/cli_args.rs: 103 changes (77 additions, 26 deletions)
@@ -1,4 +1,9 @@
use std::{fmt, ops::Deref, path::PathBuf, sync::Arc};
use std::{
fmt,
ops::Deref,
path::{Path, PathBuf},
sync::Arc,
};

use clap::{Parser, ValueEnum};
use color_eyre::eyre::{self, WrapErr};
@@ -39,7 +44,7 @@ pub enum Args {
/// Note that most, if not all, existing models are not trained for this
/// and do not support a long enough context window to be able to
/// have an extended conversation.
Chat(Box<Repl>),
Chat(Box<Chat>),

/// Quantize a GGML model to 4-bit.
Quantize(Box<Quantize>),
@@ -157,6 +162,51 @@ pub struct Repl {
pub generate: Generate,
}

#[derive(Parser, Debug)]
pub struct Chat {
#[command(flatten)]
pub model_load: ModelLoad,

/// The file to read the initial prompt/prelude from.
///
/// Must contain a `{{PROMPT}}` placeholder, which will be replaced with the
/// first user prompt.
#[arg(long, short = 'f')]
pub prelude_prompt_file: PathBuf,

/// The per-message prompt to use.
///
/// Must contain a `{{PROMPT}}` placeholder, which will be replaced with the
/// user's message.
#[arg(long, short = 'p')]
pub message_prompt: Option<String>,

/// The file to read the per-message prompt from.
///
/// Must contain a `{{PROMPT}}` placeholder, which will be replaced with the
/// user's message.
#[arg(long, short = 'q')]
pub message_prompt_file: Option<PathBuf>,

#[command(flatten)]
pub generate: Generate,
}
impl Chat {
pub fn message_prompt(&self) -> eyre::Result<String> {
if self.message_prompt.is_some() && self.message_prompt_file.is_some() {
eyre::bail!("Cannot specify both --message-prompt and --message-prompt-file")
}

if let Some(message_prompt_file) = &self.message_prompt_file {
read_prompt_file(message_prompt_file)
} else if let Some(message_prompt) = &self.message_prompt {
Ok(message_prompt.clone())
} else {
eyre::bail!("Must specify either --message-prompt or --message-prompt-file")
}
}
}

#[derive(Parser, Debug)]
pub struct Generate {
/// Sets the number of threads to use
@@ -493,33 +543,34 @@ impl ModelLoad {
pub struct PromptFile {
/// A file to read the prompt from.
#[arg(long, short = 'f', default_value = None)]
pub prompt_file: Option<String>,
pub prompt_file: Option<PathBuf>,
}
impl PromptFile {
pub fn contents(&self) -> Option<String> {
match &self.prompt_file {
Some(path) => {
match std::fs::read_to_string(path) {
Ok(mut prompt) => {
// Strip off the last character if it's exactly newline. Also strip off a single
// carriage return if it's there. Since String must be valid UTF-8 it should be
// guaranteed that looking at the string as bytes here is safe: UTF-8 non-ASCII
// bytes will always have the high bit set.
if matches!(prompt.as_bytes().last(), Some(b'\n')) {
prompt.pop();
}
if matches!(prompt.as_bytes().last(), Some(b'\r')) {
prompt.pop();
}
Some(prompt)
}
Err(err) => {
log::error!("Could not read prompt file at {path}. Error {err}");
std::process::exit(1);
}
}
}
pub fn contents(&self) -> eyre::Result<Option<String>> {
Ok(match &self.prompt_file {
Some(path) => Some(read_prompt_file(path)?),
_ => None,
})
}
}

pub fn read_prompt_file(path: &Path) -> eyre::Result<String> {
match std::fs::read_to_string(path) {
Ok(mut prompt) => {
// Strip off the last character if it's exactly newline. Also strip off a single
// carriage return if it's there. Since String must be valid UTF-8 it should be
// guaranteed that looking at the string as bytes here is safe: UTF-8 non-ASCII
// bytes will always have the high bit set.
if matches!(prompt.as_bytes().last(), Some(b'\n')) {
prompt.pop();
}
if matches!(prompt.as_bytes().last(), Some(b'\r')) {
prompt.pop();
}
Ok(prompt)
}
Err(err) => {
eyre::bail!("Could not read prompt file at {path:?}; error: {err}");
}
}
}
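
For context, the `interactive` loop in main.rs below applies whichever template is selected through a `process_prompt` helper whose body is not part of this diff. A minimal sketch of what that substitution presumably does, assuming a plain string replacement of the placeholder:

// Sketch only: the real process_prompt in binaries/llm-cli/src/main.rs is not
// shown in this diff. Assumption: it substitutes the user's text for the
// {{PROMPT}} placeholder in the selected template.
fn process_prompt(template: &str, user_message: &str) -> String {
    template.replace("{{PROMPT}}", user_message)
}

fn main() {
    // Hypothetical per-message template in the shape --message-prompt expects.
    let template = "Human: {{PROMPT}}\nAssistant:";
    assert_eq!(
        process_prompt(template, "Hello!"),
        "Human: Hello!\nAssistant:"
    );
}
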
binaries/llm-cli/src/main.rs: 73 changes (44 additions, 29 deletions)
@@ -6,7 +6,7 @@ use std::{

use clap::Parser;
use cli_args::Args;
use color_eyre::eyre::{Context, ContextCompat, Result};
use color_eyre::eyre::{bail, Context, ContextCompat, Result};
use llm::{InferenceError, InferenceFeedback, InferenceResponse};
use rustyline::{
error::ReadlineError,
@@ -31,14 +31,14 @@ fn main() -> Result<()> {
Args::Perplexity(args) => perplexity(&args),
Args::Info(args) => info(&args),
Args::PromptTokens(args) => prompt_tokens(&args),
Args::Repl(args) => interactive(&args, false),
Args::Chat(args) => interactive(&args, true),
Args::Repl(args) => repl(&args),
Args::Chat(args) => chat(&args),
Args::Quantize(args) => quantize(&args),
}
}

fn infer(args: &cli_args::Infer) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?;
let inference_session_config = args.generate.inference_session_config();
let model = args.model_load.load(args.generate.use_gpu)?;

@@ -106,7 +106,7 @@ fn infer(args: &cli_args::Infer) -> Result<()> {
}

fn perplexity(args: &cli_args::Perplexity) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?;
let inference_session_config = args.generate.inference_session_config();
let model = args.model_load.load(args.generate.use_gpu)?;
let (mut session, _) = snapshot::read_or_create_session(
@@ -183,7 +183,7 @@ fn info(args: &cli_args::Info) -> Result<()> {
}

fn prompt_tokens(args: &cli_args::PromptTokens) -> Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?;
let model = args.model_load.load(false)?;
let toks = match model.tokenizer().tokenize(&prompt, false) {
Ok(toks) => toks,
@@ -223,37 +223,54 @@ fn force_newline_event_seq() -> KeyEvent {
KeyEvent(KeyCode::Enter, Modifiers::SHIFT)
}

fn repl(args: &cli_args::Repl) -> Result<()> {
interactive(
&args.generate,
&args.model_load,
args.prompt_file.contents()?.as_deref(),
None,
)
}

fn chat(args: &cli_args::Chat) -> Result<()> {
interactive(
&args.generate,
&args.model_load,
Some(std::fs::read_to_string(&args.prelude_prompt_file)?.as_str()),
Some(&args.message_prompt()?),
)
}

fn interactive(
args: &cli_args::Repl,
// If set to false, the session will be cloned after each inference
// to ensure that previous state is not carried over.
chat_mode: bool,
generate: &cli_args::Generate,
model_load: &cli_args::ModelLoad,
initial_prompt_template: Option<&str>,
message_prompt_template: Option<&str>,
) -> Result<()> {
let prompt_file = args.prompt_file.contents();
let inference_session_config = args.generate.inference_session_config();
let model = args.model_load.load(args.generate.use_gpu)?;
let inference_session_config = generate.inference_session_config();
let model = model_load.load(generate.use_gpu)?;

let (mut session, mut session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
None,
args.generate.load_session.as_deref(),
generate.load_session.as_deref(),
inference_session_config,
);
let parameters = args.generate.inference_parameters(model.eot_token_id());
let parameters = generate.inference_parameters(model.eot_token_id());

let mut rng = generate.rng();

let mut rng = args.generate.rng();
let mut rl = rustyline::Editor::<LineContinuationValidator, DefaultHistory>::new()?;
rl.set_helper(Some(LineContinuationValidator));

rl.bind_sequence(force_newline_event_seq(), Cmd::Newline);

loop {
let readline = rl.readline(">> ");
match readline {
Ok(raw_line) => {
let line = raw_line.replace("\\\n", "\n");

let prompt = prompt_file
.as_deref()
let prompt = message_prompt_template
.or(initial_prompt_template)
.map(|pf| process_prompt(pf, &line))
.unwrap_or(line);

Expand All @@ -277,7 +294,7 @@ fn interactive(
prompt: "".into(),
parameters: &parameters,
play_back_previous_tokens: session_loaded,
maximum_token_count: args.generate.num_predict,
maximum_token_count: generate.num_predict,
},
// EvaluateOutputRequest
&mut Default::default(),
@@ -298,11 +315,11 @@ fn interactive(
}

// Reload session in REPL mode
if !chat_mode {
if message_prompt_template.is_none() {
(session, session_loaded) = snapshot::read_or_create_session(
model.as_ref(),
None,
args.generate.load_session.as_deref(),
generate.load_session.as_deref(),
inference_session_config,
);
}
@@ -315,7 +332,6 @@ fn interactive(
}
}
}

Ok(())
}

@@ -382,8 +398,8 @@ fn quantize(args: &cli_args::Quantize) -> Result<()> {
fn load_prompt_file_with_prompt(
prompt_file: &cli_args::PromptFile,
prompt: Option<&str>,
) -> String {
if let Some(prompt_file) = prompt_file.contents() {
) -> Result<String> {
Ok(if let Some(prompt_file) = prompt_file.contents()? {
if let Some(prompt) = prompt {
process_prompt(&prompt_file, prompt)
} else {
Expand All @@ -392,9 +408,8 @@ fn load_prompt_file_with_prompt(
} else if let Some(prompt) = prompt {
prompt.to_owned()
} else {
log::error!("No prompt or prompt file was provided. See --help");
std::process::exit(1);
}
bail!("No prompt or prompt file was provided. See --help");
})
}

#[derive(Completer, Helper, Highlighter, Hinter, Debug, Clone, Copy)]

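Beyond splitting the prompts, the diff also moves the CLI's error handling from `log::error!` followed by `std::process::exit(1)` to eyre-style propagation: helpers now return `Result` and callers use `?` or `bail!`. A standalone sketch of that pattern, mirroring the new `read_prompt_file` above minus the newline trimming, and assuming the color-eyre crate the CLI already imports:

// Standalone sketch of the error-propagation style adopted in this commit.
// Assumes the color-eyre crate (imported as color_eyre::eyre in the CLI).
use std::path::Path;

use color_eyre::eyre::{bail, Result};

fn read_prompt_file(path: &Path) -> Result<String> {
    match std::fs::read_to_string(path) {
        Ok(contents) => Ok(contents),
        // Return a descriptive error to the caller instead of logging and
        // exiting the process on the spot.
        Err(err) => bail!("Could not read prompt file at {path:?}; error: {err}"),
    }
}

fn main() -> Result<()> {
    color_eyre::install()?;
    let prelude = read_prompt_file(Path::new("prelude.txt"))?;
    println!("{prelude}");
    Ok(())
}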