From 5fa9bb28ce28690102e7eddd3f561a3582c55008 Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Mon, 6 Nov 2023 01:34:14 -0700 Subject: [PATCH] Update to llm-samplers v0.0.7 --- Cargo.lock | 5 ++- Cargo.toml | 3 +- binaries/llm-cli/src/cli_args.rs | 9 ++++++ binaries/llm-test/src/inference.rs | 10 +++--- crates/llm-base/src/lib.rs | 2 +- crates/llm-base/src/samplers.rs | 52 +++++++++++++++++++----------- 6 files changed, 53 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 049d70df..e2d26e79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1374,9 +1374,8 @@ dependencies = [ [[package]] name = "llm-samplers" -version = "0.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7553f60d113c9cdc6a5402456a31cd9a273bef79f6f16d8a4f7b4bedf5f754b2" +version = "0.0.7" +source = "git+https://github.com/KerfuffleV2/llm-samplers?branch=feat-v0.0.7#8c72d0c2838471bfbe26394694b41054bd789549" dependencies = [ "anyhow", "num-traits", diff --git a/Cargo.toml b/Cargo.toml index ae5b22f7..2daf8d62 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,8 @@ clap = { version = "4.1.8", features = ["derive"] } memmap2 = "0.5.10" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing = { version = "0.1", features = ["log"] } -llm-samplers = "=0.0.6" +llm-samplers = { git = "https://github.com/KerfuffleV2/llm-samplers", branch = "feat-v0.0.7" } +# llm-samplers = "=0.0.6" # Config for 'cargo dist' [workspace.metadata.dist] diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 21b4a897..5dc2b1e6 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -290,6 +290,15 @@ pub struct Generate { /// top_p - The probability for the top tokens are added until the result is greater or equal to P and at least min_keep tokens have been seen. /// p(0.95): The cumulative probability after which no more tokens are kept for sampling. /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. + /// + /// top_a (default: disabled) - This sampler prunes tokens that don't meet a threshold based on the most probable token. The formula is `a1 * pow(max_prob, a2)`. See https://github.com/BlinkDL/RWKV-LM#the-top-a-sampling-method for more information. + /// a1(0.0): Threshold scale. A reasonable value is 0.2. Setting either a1 or a2 to 0 disables the sampler. + /// a2(0.0): Threshold power. A reasonable value is 2. + /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. + /// + /// min_p (default: disabled) - This sampler prunes tokens that don't meet a certain percentage of the most probable token. For example if `p` is `0.05` then after `min_keep` is satisfied, other tokens must be at least 5% of the most probable token. See https://github.com/ggerganov/llama.cpp/issues/3483 for more information. + /// p(0.0): Probability threshold. 0.05 to 0.2 are good starting values to try. Setting this to 0 disables the sampler. + /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. #[arg(long = "sampler", short = 's', verbatim_doc_comment)] pub sampler_options: Vec, diff --git a/binaries/llm-test/src/inference.rs b/binaries/llm-test/src/inference.rs index a9ace889..3666167e 100644 --- a/binaries/llm-test/src/inference.rs +++ b/binaries/llm-test/src/inference.rs @@ -92,14 +92,14 @@ fn run_inference( // Takes the most likely element from the logits, except if they've appeared in `previous_tokens` // at all #[derive(Debug, Default)] -struct DeterministicSampler(SampleGreedy); +struct DeterministicSampler(SampleGreedy); -impl Sampler for DeterministicSampler { +impl Sampler for DeterministicSampler { fn sample<'a>( &mut self, - res: &mut dyn HasSamplerResources, - logits: &'a mut Logits, - ) -> anyhow::Result<&'a mut Logits> { + res: &mut dyn HasSamplerResources, + logits: &'a mut Logits, + ) -> anyhow::Result<&'a mut Logits> { let mut flat_bias = Default::default(); // This might look a little weird, but it's necessary because the resource diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index e07c8852..f0a88a8a 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -60,7 +60,7 @@ pub struct InferenceParameters { /// This can be anything that implements [Sampler]. Refer to /// the `llm-samplers` documentation for possible samplers and suggested /// combinations: - pub sampler: Arc>>, + pub sampler: Arc>, } //Since Sampler implements Send and Sync, InferenceParameters should too. diff --git a/crates/llm-base/src/samplers.rs b/crates/llm-base/src/samplers.rs index 7a179f0b..f0b07b9e 100644 --- a/crates/llm-base/src/samplers.rs +++ b/crates/llm-base/src/samplers.rs @@ -59,7 +59,7 @@ pub enum SamplingError { /// to ensure a valid configuration. pub struct ConfiguredSamplers { /// A builder from the `llm-samplers` crate. - pub builder: SamplerChainBuilder, + pub builder: SamplerChainBuilder, /// Mirostat 1 is present. pub mirostat1: bool, /// Mirostat 2 is present. @@ -74,15 +74,17 @@ pub struct ConfiguredSamplers { /// We call a configuration of samplers that run in a certain order a "chain". /// Here is a description of the default chain `llm` uses: /// -/// 1. Repetition (present by default, multiple allowed) -/// 2. Frequency/Presence (optional, multiple allowed) -/// 3. Sequence Repetition (optional, multiple allowed) -/// 4. Top-K (present by default - incompatible with Mirostat) -/// 5. Tail Free (optional - incompatible with Mirostat) -/// 6. Locally Typical (optional - incompatible with Mirostat) -/// 7. Top-P (present by default - incompatible with Mirostat) -/// 8. Temperature (present by default) -/// 9. A Mirostat 1 or 2 sampler if configured, otherwise Random Distribution. +/// 1. Repetition (present by default, multiple allowed) +/// 2. Frequency/Presence (optional, multiple allowed) +/// 3. Sequence Repetition (optional, multiple allowed) +/// 4. Top-K (present by default - incompatible with Mirostat) +/// 5. Tail Free (optional - incompatible with Mirostat) +/// 6. Locally Typical (optional - incompatible with Mirostat) +/// 7. Top-P (present by default - incompatible with Mirostat) +/// 8. Top-A (optional - incompatible with Mirostat) +/// 9. Min-P (optional - incompatible with Mirostat) +/// 10. Temperature (present by default) +/// 11. A Mirostat 1 or 2 sampler if configured, otherwise Random Distribution. /// /// Samplers listed as "present by default" but incompatible with Mirostat will /// only be enabled by default if there is no Mirostat sampler enabled. @@ -142,6 +144,20 @@ impl Default for ConfiguredSamplers { Option::::None, ), ), + ( + "topa", + SamplerSlot::new_single( + || Box::new(SampleTopA::default().a1(0.0).a2(0.0)), + Option::::None, + ), + ), + ( + "minp", + SamplerSlot::new_single( + || Box::new(SampleMinP::default().p(0.0)), + Option::::None, + ), + ), ( "temperature", SamplerSlot::new_single( @@ -203,7 +219,7 @@ impl ConfiguredSamplers { ))? } else if (self.mirostat1 || self.mirostat2) && self.incompat_mirostat { Err(SamplerConfigurationError::SamplerCombinationError( - "Cannot enable top-p, top-k, locally typical or tail free samplers with Mirostat 1 or 2".to_string(), + "Cannot enable top-p, top-k, top-a, min-p, locally typical or tail free samplers with Mirostat 1 or 2".to_string(), ))? } Ok(()) @@ -245,7 +261,9 @@ impl FromStr for ConfiguredSamplers { .inspect(|(name, _slot)| match name.as_str() { "mirostat1" => result.mirostat1 = true, "mirostat2" => result.mirostat2 = true, - "topp" | "topk" | "locallytypical" | "tailfree" => result.incompat_mirostat = true, + "topa" | "minp" | "topp" | "topk" | "locallytypical" | "tailfree" => { + result.incompat_mirostat = true + } _ => (), }) .collect::>(); @@ -269,7 +287,7 @@ impl FromStr for ConfiguredSamplers { /// Sample a token. This convenience function handles building /// the sampler resources and logits objects the sampler needs. pub fn sample_token( - mut sampler: impl Sampler, + mut sampler: impl Sampler, rng: &mut impl rand::Rng, previous_tokens: &[TokenId], last_logits: impl IntoIterator, @@ -297,7 +315,7 @@ pub fn build_sampler( n_vocab: usize, bias: &[(TokenId, f32)], args: &[impl AsRef], -) -> Result>>, SamplerConfigurationError> { +) -> Result>, SamplerConfigurationError> { let mut samplers = SamplerChain::new(); if !bias.is_empty() { @@ -326,7 +344,7 @@ pub fn build_sampler( } /// Get the default sampler chain. -pub fn default_samplers() -> Arc>> { +pub fn default_samplers() -> Arc> { let mut result = ConfiguredSamplers::default(); result.ensure_default_slots(); Arc::new(Mutex::new(result.builder.into_chain())) @@ -349,8 +367,6 @@ impl<'pt, 'r> fmt::Debug for SamplerResources<'pt, 'r> { } impl<'pt, 'r> HasSamplerResources for SamplerResources<'pt, 'r> { - type TokenId = TokenId; - fn with_rng_mut( &mut self, fun: &mut dyn FnMut(&mut dyn rand::RngCore), @@ -359,7 +375,7 @@ impl<'pt, 'r> HasSamplerResources for SamplerResources<'pt, 'r> { Ok(()) } - fn with_last_tokens(&self, fun: &mut dyn FnMut(&[Self::TokenId])) -> Result<(), SamplerError> { + fn with_last_tokens(&self, fun: &mut dyn FnMut(&[TokenId])) -> Result<(), SamplerError> { fun(self.previous_tokens); Ok(()) }