Skip to content

Commit

Permalink
Add min-P sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
turboderp committed Nov 11, 2023
1 parent 0261a6a commit 0d436d7
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 4 deletions.
36 changes: 36 additions & 0 deletions exllamav2/exllamav2_ext/cpp/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,42 @@ int top_p_cpu
return k;
}

// Min-P filter.
//
// Keeps only the candidates whose probability is at least `min_p` times the
// largest probability among the first `num_candidates` entries.
//
// The two parallel arrays are partitioned in place: on return, the surviving
// candidates occupy positions [0, result) of both `temp_probs` and
// `temp_indices` (each probability stays paired with its index), and the
// rejected ones occupy [result, num_candidates). The relative order of the
// candidates is not preserved and the probabilities are not rescaled here.
//
// Parameters:
//   num_candidates  number of valid entries in both arrays
//   temp_probs      candidate probabilities, reordered in place
//   temp_indices    candidate indices parallel to temp_probs, reordered in place
//   min_p           fraction of the top probability used as the cutoff
//
// Returns the number of surviving candidates (0 if num_candidates <= 0).
int min_p_cpu
(
    const int num_candidates,
    float* temp_probs,
    int* temp_indices,
    float min_p
)
{
    // TIME_START;

    // Nothing to filter; also avoids reading temp_probs[0] from an empty set.
    if (num_candidates <= 0) return 0;

    float top_prob = temp_probs[0];
    for (int i = 1; i < num_candidates; i++)
        if (temp_probs[i] > top_prob) top_prob = temp_probs[i];

    const float threshold = top_prob * min_p;

    // Invariant: everything left of `keep` passes the threshold, everything
    // right of `tail` fails it. Every access stays within [keep, tail], so
    // nothing outside [0, num_candidates) is ever read or written.
    int keep = 0;
    int tail = num_candidates - 1;

    while (keep <= tail)
    {
        if (temp_probs[keep] >= threshold)
        {
            keep++;
            continue;
        }

        // temp_probs[keep] fails. If temp_probs[tail] passes then
        // keep < tail necessarily, so the swap moves a survivor forward
        // and the rejected entry to the back.
        if (temp_probs[tail] >= threshold)
        {
            const float prob_tmp = temp_probs[keep];
            temp_probs[keep] = temp_probs[tail];
            temp_probs[tail] = prob_tmp;

            const int index_tmp = temp_indices[keep];
            temp_indices[keep] = temp_indices[tail];
            temp_indices[tail] = index_tmp;

            keep++;
        }
        tail--;
    }

    // TIME_STOP;

    return keep;
}

int typical_cpu
(
const int num_candidates,
Expand Down
8 changes: 8 additions & 0 deletions exllamav2/exllamav2_ext/cpp/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ int top_p_cpu
float top_p
);

// Min-P filter: keeps the candidates whose probability is at least `min_p`
// times the largest probability among the first `num_candidates` entries.
// `temp_probs` and `temp_indices` are parallel arrays that are reordered in
// place so the kept candidates come first. Returns the number of candidates
// kept.
int min_p_cpu
(
const int num_candidates,
float* temp_probs,
int* temp_indices,
float min_p
);

int typical_cpu
(
const int num_candidates,
Expand Down
7 changes: 7 additions & 0 deletions exllamav2/exllamav2_ext/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,7 @@ void sample_basic
float temperature,
int top_k,
float top_p,
float min_p,
float typical,
float random,
torch::Tensor output_tokens, // shape [bsz, 1]
Expand Down Expand Up @@ -810,6 +811,12 @@ void sample_basic
normalize_cpu(num_candidates, temp_probs);
}

if (min_p > 0.0f && min_p < 1.0f)
{
num_candidates = min_p_cpu(num_candidates, temp_probs, temp_indices, min_p);
normalize_cpu(num_candidates, temp_probs);
}

if (typical > 0.0f && typical < 1.0f)
{
num_candidates = typical_cpu(num_candidates, temp_probs, temp_indices, typical);
Expand Down
10 changes: 6 additions & 4 deletions exllamav2/generator/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@ class ExLlamaV2Sampler:

class Settings:

token_repetition_penalty = 1.15
token_repetition_range = -1
token_repetition_decay = 0

temperature = 0.9
top_k = 40
top_p = 0.9
min_p = 0
typical = 0

token_repetition_penalty = 1.15
token_repetition_range = -1
token_repetition_decay = 0

token_bias = None

filters = []
Expand Down Expand Up @@ -134,6 +135,7 @@ def sample(logits: torch.tensor, settings: Settings, sequence_ids: torch.tensor,
settings.temperature,
settings.top_k,
settings.top_p,
settings.min_p,
settings.typical,
random,
output_tokens,
Expand Down

0 comments on commit 0d436d7

Please sign in to comment.