[REST] OpenAI-compatible REST API (mlc-ai#1107)
* add presence and frequency penalty

* Added support for passing conversation history in /v1/chat/completions endpoint

* Added support for REST API parameters max_gen_len, n, and stop_str

* add presence and frequency penalty to generation config
* refactor generation config

* Added documentation for parameters

* replace lib_path with model_lib_path in rest.py

* fixed black/isort issues

* fix lib_path
Kartik14 authored Oct 24, 2023
1 parent 9cb8e8e commit 9166edb
Showing 4 changed files with 288 additions and 130 deletions.
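For orientation before the diff: the REST layer forwards per-request settings to the C++ runtime as a JSON string (`generation_config_str`). Below is a minimal sketch, assuming only `picojson.h` on the include path, of parsing such a config the way the changed code does; the field values are illustrative, and the exact request schema handled by `rest.py` is not reproduced here.

```cpp
#include <iostream>
#include <string>

#include "picojson.h"

int main() {
  // A per-request generation config of the kind the REST layer forwards to
  // the runtime as `generation_config_str`. Field names come from the diff;
  // the values are illustrative only.
  const std::string generation_config_str = R"({
    "temperature": 0.7,
    "top_p": 0.95,
    "presence_penalty": 0.5,
    "frequency_penalty": 0.5,
    "max_gen_len": 256,
    "stop": ["</s>"]
  })";

  picojson::value parsed;
  const std::string err = picojson::parse(parsed, generation_config_str);
  if (!err.empty()) {
    std::cerr << "config parse error: " << err << "\n";
    return 1;
  }
  picojson::object config = parsed.get<picojson::object>();
  // Every field is optional; anything absent falls back to the model default.
  if (config.count("presence_penalty")) {
    std::cout << "presence_penalty = "
              << config["presence_penalty"].get<double>() << "\n";
  }
  return 0;
}
```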
232 changes: 151 additions & 81 deletions cpp/llm_chat.cc
@@ -650,7 +650,7 @@ class LLMChat {
this->ResetRuntimeStats();
}
output_ids_.clear();
appeared_token_ids_.clear();
appeared_token_freq_.clear();
output_message_.clear();
stop_triggered_ = false;
if (append_conversation) {
@@ -672,12 +672,8 @@
PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll,
String generation_config_str = "") {
// process generation settings
picojson::object generation_config = picojson::object();
if (!generation_config_str.empty()) {
picojson::value generation_config_json;
picojson::parse(generation_config_json, generation_config_str);
generation_config = generation_config_json.get<picojson::object>();
}
picojson::object generation_config =
this->LoadGenerationConfigFromString(generation_config_str);

std::vector<int32_t> prompt_tokens =
PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt, generation_config);
@@ -729,21 +725,16 @@
return;
}

// process generation settings
picojson::object generation_config = picojson::object();
if (!generation_config_str.empty()) {
picojson::value generation_config_json;
picojson::parse(generation_config_json, generation_config_str);
generation_config = generation_config_json.get<picojson::object>();
}
picojson::object generation_config =
this->LoadGenerationConfigFromString(generation_config_str);

int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config);

auto tend = std::chrono::high_resolution_clock::now();

this->prefill_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->prefill_total_tokens += token_len;
this->ProcessNextToken(next_token);
this->ProcessNextToken(next_token, generation_config);
}

/*!
@@ -768,13 +759,8 @@
return;
}

// process generation settings
picojson::object generation_config = picojson::object();
if (!generation_config_str.empty()) {
picojson::value generation_config_json;
picojson::parse(generation_config_json, generation_config_str);
generation_config = generation_config_json.get<picojson::object>();
}
picojson::object generation_config =
this->LoadGenerationConfigFromString(generation_config_str);

std::vector<int32_t> prompt_tokens =
this->PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt, generation_config);
@@ -803,17 +789,12 @@

this->prefill_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->prefill_total_tokens += token_len;
this->ProcessNextToken(next_token);
this->ProcessNextToken(next_token, generation_config);
}

void DecodeStep(String generation_config_str = "") {
// process generation settings
picojson::object generation_config = picojson::object();
if (!generation_config_str.empty()) {
picojson::value generation_config_json;
picojson::parse(generation_config_json, generation_config_str);
generation_config = generation_config_json.get<picojson::object>();
}
picojson::object generation_config =
this->LoadGenerationConfigFromString(generation_config_str);

ICHECK(!output_ids_.empty());
int32_t last_token = output_ids_.back();
@@ -830,7 +811,7 @@

this->decode_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->decode_total_tokens += 1;
this->ProcessNextToken(next_token);
this->ProcessNextToken(next_token, generation_config);
}

bool Stopped() { return stop_triggered_; }
@@ -931,61 +912,109 @@ class LLMChat {
picojson::object config;
config["temperature"] = picojson::value(this->temperature_);
config["repetition_penalty"] = picojson::value(this->repetition_penalty_);
config["presence_penalty"] = picojson::value(this->presence_penalty_);
config["frequency_penalty"] = picojson::value(this->frequency_penalty_);
config["top_p"] = picojson::value(this->top_p_);
config["mean_gen_len"] = picojson::value(this->mean_gen_len_);
config["max_gen_len"] = picojson::value(this->max_gen_len_);
config["shift_fill_factor"] = picojson::value(this->shift_fill_factor_);
config["conv_config"] = this->conversation_.SerializeToJSON();
return picojson::value(config);
}
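To make the two new fields concrete, here is a small standalone sketch of the object `SerializeConfigToJSONValue` now builds; the values are placeholders, not the real model defaults.

```cpp
#include <iostream>

#include "picojson.h"

// Standalone sketch of the config object assembled above; the numeric
// values here are placeholders, not the real chat-config defaults.
int main() {
  picojson::object config;
  config["temperature"] = picojson::value(0.7);
  config["repetition_penalty"] = picojson::value(1.0);
  config["presence_penalty"] = picojson::value(0.0);   // new in this commit
  config["frequency_penalty"] = picojson::value(0.0);  // new in this commit
  config["top_p"] = picojson::value(0.95);
  std::cout << picojson::value(config).serialize() << "\n";
  return 0;
}
```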
/*!
* \brief Sample output token from logits on device
*/
int32_t SampleTokenFromLogits(NDArray logits_on_device,
picojson::object generation_config = picojson::object()) {
// prepare generation settings
// the generation_config will not override the original config
// since it is only used for this generation
double gen_temperature;
double gen_repetition_penalty;
double gen_top_p;

picojson::object LoadGenerationConfigFromString(const std::string& generation_config_str) {
picojson::object generation_config = picojson::object();
if (!generation_config_str.empty()) {
picojson::value generation_config_json;
picojson::parse(generation_config_json, generation_config_str);
generation_config = generation_config_json.get<picojson::object>();
}
return generation_config;
}

void ReadGenerationConfig(picojson::object generation_config, double* gen_temperature,
NDArray* gen_temperature_arr, double* gen_repetition_penalty,
double* gen_presence_penalty, double* gen_frequency_penalty,
double* gen_top_p) {
if (generation_config.count("temperature")) {
CHECK(generation_config["temperature"].is<double>());
gen_temperature = generation_config["temperature"].get<double>();
if (gen_temperature != this->temperature_) {
this->temperature_ = gen_temperature;
float temperature_cast = static_cast<float>(gen_temperature);
this->temperature_arr_.CopyFromBytes(&temperature_cast, sizeof(float));
}
*gen_temperature = generation_config["temperature"].get<double>();

*gen_temperature_arr = NDArray::Empty({}, DataType::Float(32), device_);
float temperature_cast = static_cast<float>(*gen_temperature);
gen_temperature_arr->CopyFromBytes(&temperature_cast, sizeof(float));
} else {
gen_temperature = this->temperature_;
*gen_temperature = this->temperature_;
*gen_temperature_arr = this->temperature_arr_;
}
if (generation_config.count("repetition_penalty")) {
CHECK(generation_config["repetition_penalty"].is<double>());
gen_repetition_penalty = generation_config["repetition_penalty"].get<double>();
CHECK(generation_config["repetition_penalty"].get<double>() > 0)
<< "Repetition penalty must be a positive number!";
*gen_repetition_penalty = generation_config["repetition_penalty"].get<double>();
} else {
*gen_repetition_penalty = this->repetition_penalty_;
}
if (generation_config.count("presence_penalty")) {
CHECK(generation_config["presence_penalty"].is<double>());
CHECK(std::abs(generation_config["presence_penalty"].get<double>()) <= 2)
<< "Presence penalty must be in the range -2 to 2!";
*gen_presence_penalty = generation_config["presence_penalty"].get<double>();
} else {
*gen_presence_penalty = this->presence_penalty_;
}
if (generation_config.count("frequency_penalty")) {
CHECK(generation_config["frequency_penalty"].is<double>());
CHECK(std::abs(generation_config["frequency_penalty"].get<double>()) <= 2)
<< "Frequency penalty must be in the range -2 to 2!";
*gen_frequency_penalty = generation_config["frequency_penalty"].get<double>();
} else {
gen_repetition_penalty = this->repetition_penalty_;
*gen_frequency_penalty = this->frequency_penalty_;
}
if (generation_config.count("top_p")) {
CHECK(generation_config["top_p"].is<double>());
gen_top_p = generation_config["top_p"].get<double>();
*gen_top_p = generation_config["top_p"].get<double>();
} else {
gen_top_p = this->top_p_;
*gen_top_p = this->top_p_;
}
}
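Every field above follows one pattern: use the per-request value when the key is present and well-typed, otherwise fall back to the engine default. Temperature is the exception, since it also materializes a device-side `NDArray`. A distilled sketch, with a hypothetical helper (not part of the diff):

```cpp
#include <cassert>
#include <string>

#include "picojson.h"

// Hypothetical helper distilling the per-field pattern of ReadGenerationConfig:
// prefer the per-request value when present and well-typed, otherwise fall
// back to the engine-level default.
static double OverrideOrDefault(picojson::object& config, const std::string& key,
                                double fallback) {
  if (!config.count(key)) return fallback;
  assert(config[key].is<double>());
  return config[key].get<double>();
}
```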

/*!
* \brief Sample output token from logits on device
*/
int32_t SampleTokenFromLogits(NDArray logits_on_device,
picojson::object generation_config = picojson::object()) {
// prepare generation settings
// the generation_config will not override the original config
// since it is only used for this generation
double gen_temperature;
double gen_repetition_penalty;
double gen_presence_penalty;
double gen_frequency_penalty;
double gen_top_p;
this->ReadGenerationConfig(generation_config, &gen_temperature, &this->temperature_arr_,
&gen_repetition_penalty, &gen_presence_penalty,
&gen_frequency_penalty, &gen_top_p);

// update logits
if (gen_repetition_penalty == 1.0f) {
if (gen_temperature < 1e-6f) {
this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
} else {
this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_));
if (gen_presence_penalty != 0.0f || gen_frequency_penalty != 0.0f) {
this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
this->ApplyPresenceAndFrequencyPenaltyOnCPU(gen_presence_penalty, gen_frequency_penalty);
if (gen_temperature >= 1e-6f) {
this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature);
}
} else {
} else if (gen_repetition_penalty != 1.0f) {
this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
this->ApplyRepetitionPenaltyOnCPU(gen_repetition_penalty);
if (gen_temperature >= 1e-6f) {
this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature);
}
} else {
if (gen_temperature < 1e-6f) {
this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
} else {
this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_));
}
}

// perform sampling
Expand Down Expand Up @@ -1018,6 +1047,25 @@ class LLMChat {
gen_max_gen_len = this->max_gen_len_;
}

std::vector<std::string> gen_stop_strs;
gen_stop_strs.push_back(conversation_.stop_str);

if (generation_config.count("stop")) {
if (!generation_config["stop"].is<picojson::null>()) {
CHECK(generation_config["stop"].is<std::string>() ||
generation_config["stop"].is<picojson::array>());
if (generation_config["stop"].is<std::string>()) {
gen_stop_strs.push_back(generation_config["stop"].get<std::string>());
} else {
picojson::array gen_stop_strs_arr = generation_config["stop"].get<picojson::array>();
for (const picojson::value& v : gen_stop_strs_arr) {
CHECK(v.is<std::string>());
gen_stop_strs.push_back(v.get<std::string>());
}
}
}
}

ICHECK(!stop_triggered_) << "Cannot call process when it is stopped";

stop_triggered_ =
@@ -1026,27 +1074,35 @@

if (!stop_triggered_) {
output_ids_.push_back(next_token);
appeared_token_ids_.insert(next_token);
if (appeared_token_freq_.find(next_token) != appeared_token_freq_.end()) {
appeared_token_freq_[next_token] += 1;
} else {
appeared_token_freq_[next_token] = 1;
}
}

output_message_ = tokenizer_->Decode(output_ids_);

if (!conversation_.stop_str.empty()) {
size_t stop_pos = output_message_.rfind(conversation_.stop_str);
if (stop_pos != std::string::npos) {
stop_triggered_ = true;
if (ft_.support_backtracking_kv_) {
// backtracking: pop tokens until the decoded message
// no longer reaches the stop position
size_t backoff = 0;
for (; (output_ids_.size() > 0) && (output_message_.length() > stop_pos); ++backoff) {
output_ids_.pop_back();
output_message_ = tokenizer_->Decode(output_ids_);
}
// resize kv to remove the context
ft_.fkvcache_array_popn_(kv_cache_, backoff);
total_seq_len_ -= backoff;
size_t stop_pos = std::string::npos;
for (const std::string& stop_str : gen_stop_strs) {
if (!stop_str.empty()) {
stop_pos = std::min(stop_pos, output_message_.rfind(stop_str));
}
}

if (stop_pos != std::string::npos) {
stop_triggered_ = true;
if (ft_.support_backtracking_kv_) {
// backtracking: pop tokens until the decoded message
// no longer reaches the stop position
size_t backoff = 0;
for (; (output_ids_.size() > 0) && (output_message_.length() > stop_pos); ++backoff) {
output_ids_.pop_back();
output_message_ = tokenizer_->Decode(output_ids_);
}
// resize kv to remove the context
ft_.fkvcache_array_popn_(kv_cache_, backoff);
total_seq_len_ -= backoff;
}
}
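The scan above generalizes the old single `conversation_.stop_str` check to a list of stop strings. A self-contained sketch of the same logic, under a hypothetical function name:

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical standalone version of the scan above: for each non-empty stop
// string, rfind gives the position of its last occurrence (npos if absent);
// std::min keeps the earliest such position, and npos (the largest size_t)
// survives only when nothing matched.
size_t FindStopPos(const std::string& output_message,
                   const std::vector<std::string>& stop_strs) {
  size_t stop_pos = std::string::npos;
  for (const std::string& stop_str : stop_strs) {
    if (!stop_str.empty()) {
      stop_pos = std::min(stop_pos, output_message.rfind(stop_str));
    }
  }
  return stop_pos;
}
```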

@@ -1113,15 +1169,25 @@
CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!";
CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
float* logits_raw_data = static_cast<float*>(logits_on_cpu_->data);
for (const int32_t& token_id : this->appeared_token_ids_) {
if (logits_raw_data[token_id] <= 0) {
logits_raw_data[token_id] *= repetition_penalty;
for (const auto& token_freq : this->appeared_token_freq_) {
if (logits_raw_data[token_freq.first] <= 0) {
logits_raw_data[token_freq.first] *= repetition_penalty;
} else { // logits > 0
logits_raw_data[token_id] /= repetition_penalty;
logits_raw_data[token_freq.first] /= repetition_penalty;
}
}
}

void ApplyPresenceAndFrequencyPenaltyOnCPU(float presence_penalty, float frequency_penalty) {
CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!";
CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
float* logits_raw_data = static_cast<float*>(logits_on_cpu_->data);
for (const auto& token_freq : this->appeared_token_freq_) {
logits_raw_data[token_freq.first] -=
(token_freq.second * frequency_penalty + presence_penalty);
}
}
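A self-contained numeric sketch of the penalty above, with illustrative values: presence penalizes any token that has appeared at all, frequency scales with its count, and both subtract from the logit, unlike the multiplicative repetition penalty.

```cpp
#include <cstdint>
#include <cstdio>
#include <unordered_map>
#include <vector>

int main() {
  // Logits for a toy 4-token vocabulary (illustrative values).
  std::vector<float> logits = {2.0f, 0.5f, -1.0f, 3.0f};
  // Token 0 has appeared 3 times so far, token 2 once.
  std::unordered_map<int32_t, int64_t> appeared_token_freq = {{0, 3}, {2, 1}};
  const float presence_penalty = 0.5f;   // checked to lie in [-2, 2]
  const float frequency_penalty = 0.3f;  // checked to lie in [-2, 2]

  // logit[t] -= count(t) * frequency_penalty + presence_penalty
  for (const auto& token_freq : appeared_token_freq) {
    logits[token_freq.first] -=
        token_freq.second * frequency_penalty + presence_penalty;
  }
  // token 0: 2.0 - (3 * 0.3 + 0.5) = 0.6
  // token 2: -1.0 - (1 * 0.3 + 0.5) = -1.8
  for (float v : logits) std::printf("%.2f ", v);
  std::printf("\n");
  return 0;
}
```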

void ApplySoftmaxWithTemperatureOnCPU(float temperature) {
CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!";
CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
@@ -1211,12 +1277,16 @@ class LLMChat {
NDArray temperature_arr_;
// repetition penalty
double repetition_penalty_{1.0};
// presence penalty
double presence_penalty_{0.0};
// frequency penalty
double frequency_penalty_{0.0};
// top_p
double top_p_{0.95};
// output ids till now (refresh after encoding step)
std::vector<int32_t> output_ids_;
// appeared token ids till now (refresh after encoding step)
std::unordered_set<int32_t> appeared_token_ids_;
// frequency of appeared token ids till now (refresh after encoding step)
std::unordered_map<int32_t, int64_t> appeared_token_freq_;
// output message till now (refresh after encoding step)
std::string output_message_;
// Whether encounter stop str