gpt-2 : fix allocr worst-case when n_parallel > prompt size
ggerganov committed Oct 20, 2023
1 parent f25addb · commit 53590e3
Showing 1 changed file with 21 additions and 17 deletions.
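
In the old flow, the prompt was loaded into the batch (batch.n_tokens = embd_inp.size()) before the ggml_allocr measurement pass, so the worst-case graph used to size the compute buffer only covered the prompt length. Later decode steps can submit up to n_parallel tokens in a single batch, so whenever n_parallel exceeded the prompt size the buffer was under-allocated. The fix computes n_batch_max = max(embd_inp.size(), n_parallel), builds the measurement graph with that token count, and defers the prompt evaluation until after the compute buffer exists; see the sketch after the diff.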
examples/gpt-2/main-batched.cpp: 38 changes (21 additions, 17 deletions)
@@ -1032,31 +1032,22 @@ int main(int argc, char ** argv) {
     // keep this buffer alive while evaluating the model
     ggml_backend_buffer_t buf_compute;
 
-    // create a gpt2_batch
-    // we use this object to submit token data for decoding
     const int n_parallel = params.n_parallel;
-    gpt2_batch batch = gpt2_batch_init(std::max(embd_inp.size(), (size_t)n_parallel), 0);
-
-    // evaluate the initial prompt
-    batch.n_tokens = embd_inp.size();
+    const int n_batch_max = std::max(embd_inp.size(), (size_t)n_parallel);
 
-    for (int32_t i = 0; i < batch.n_tokens; i++) {
-        batch.token[i]  = embd_inp[i];
-        batch.pos[i]    = i;
-        batch.seq_id[i] = 0;
-        batch.logits[i] = false;
-    }
-
-    // gpt2_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    // create a gpt2_batch
+    // we use this object to submit token data for decoding
+    gpt2_batch batch = gpt2_batch_init(n_batch_max, 0);
 
+    // prepare required memory and allocate the compute buffer
     struct ggml_allocr * allocr = NULL;
-    // allocate the compute buffer
     {
-    // alignment required by the backend
+        // alignment required by the backend
         size_t align = ggml_backend_get_alignment(model.backend);
         allocr = ggml_allocr_new_measure(align);
 
+        batch.n_tokens = n_batch_max;
+
        // create the worst case graph for memory usage estimation
        struct ggml_cgraph * gf = gpt2_graph(model, allocr, batch);
 
@@ -1076,6 +1067,19 @@
 
     std::vector<float> logits;
 
+    // evaluate the initial prompt
+    batch.n_tokens = embd_inp.size();
+
+    for (int32_t i = 0; i < batch.n_tokens; i++) {
+        batch.token[i]  = embd_inp[i];
+        batch.pos[i]    = i;
+        batch.seq_id[i] = 0;
+        batch.logits[i] = false;
+    }
+
+    // gpt2_decode will output logits only for the last token of the prompt
+    batch.logits[batch.n_tokens - 1] = true;
+
     if (gpt2_decode(model, allocr, batch, params.n_threads, logits) != 0) {
         printf("%s: gpt2_decode() failed\n", __func__);
         return 1;
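
For reference, a minimal sketch of the measure-then-allocate flow this commit establishes. It assumes the surrounding definitions from main-batched.cpp (model, params, embd_inp, gpt2_batch, gpt2_batch_init, gpt2_graph) and the ggml allocr API of this era; the buffer-allocation calls at the end mirror the neighboring gpt-2 example code rather than this diff, so treat them as an assumption, not a drop-in excerpt.

    // the largest batch we will ever submit: the whole prompt,
    // or one token per parallel sequence, whichever is bigger
    const int n_parallel  = params.n_parallel;
    const int n_batch_max = std::max(embd_inp.size(), (size_t)n_parallel);

    gpt2_batch batch = gpt2_batch_init(n_batch_max, 0);

    // measure pass: the measuring allocator only records sizes, no real memory yet
    size_t align = ggml_backend_get_alignment(model.backend);
    struct ggml_allocr * allocr = ggml_allocr_new_measure(align);

    batch.n_tokens = n_batch_max;                          // worst case: a full batch
    struct ggml_cgraph * gf = gpt2_graph(model, allocr, batch);
    size_t mem_size = ggml_allocr_alloc_graph(allocr, gf); // required bytes

    // swap the measure allocator for a real one over a buffer of that size
    // (assumed API, matching the era's gpt-2 example code)
    ggml_allocr_free(allocr);
    ggml_backend_buffer_t buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size);
    allocr = ggml_allocr_new_from_buffer(buf_compute);

    // only now evaluate the prompt; embd_inp.size() <= n_batch_max by construction
    batch.n_tokens = embd_inp.size();

Because every subsequent gpt2_decode call reuses a buffer sized for n_batch_max tokens, a decode step that batches one token for each of the n_parallel sequences can no longer overflow the measured allocation.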