Upstream merge oct31 #41

Merged: 75 commits, Oct 30, 2023. The diff below shows changes from a single commit.
Commits
898db76
[API] Add GenerationConfig (#1024)
davidpissarra Oct 8, 2023
ad3a6b9
Fix two bugs in kv-cache backtrack loop (#856)
shenberg Oct 8, 2023
6e40c21
[Build] Added --pdb flag to build.py, drop into pdb on error (#1017)
Lunderberg Oct 8, 2023
bae37b3
[Android] Use `AlertDialog` instead of `Toast` (#1039)
cyx-6 Oct 8, 2023
b44f679
Add doc for ChatConfig, ConvConfig, GenerationConfig, BuildArgs (#1040)
CharlieFRuan Oct 9, 2023
3a9849a
[Android] Add Llama2 q4f16_0 (#1041)
spectrometerHBH Oct 9, 2023
bed9e60
[Docs] Model prebuilts tracking page revamp (#1000)
CharlieFRuan Oct 9, 2023
c02fdaf
Update compile_models.rst (#1038)
yongjer Oct 9, 2023
85001ed
Support for the Stable LM 3B model (#1008)
jeethu Oct 9, 2023
a032d40
[Docs] Iterate model prebuilts docs (#1043)
CharlieFRuan Oct 9, 2023
a58605f
Update README.md
junrushao Oct 9, 2023
bdd9d9b
[CPP] Separate common utils out from llm_chat.cc (#1044)
MasterJH5574 Oct 9, 2023
20131fb
Update README.md (#1045)
junrushao Oct 9, 2023
1e6fb11
add verbose stats to mlc-chat REST API (#1049)
denise-k Oct 11, 2023
b9179cf
[Transform] Apply split_rotary optimization on prefill (#1033)
Lunderberg Oct 12, 2023
98ebd28
[Docs] Add `mlc.ai/package` to `DEPENDENCY INSTALLATION` group (#1055)
LeshengJin Oct 12, 2023
bfaa5b9
Revert "[Transform] Apply split_rotary optimization on prefill (#1033…
MasterJH5574 Oct 12, 2023
ca8c11b
[BugFix] Set the right `max_sequence_length` for both Llama-1 and Lla…
sunggg Oct 13, 2023
edab9b5
[Doc] Use -U instead of --force-reinstall (#1062)
junrushao Oct 13, 2023
d854105
[Model] Initial batching support for Llama (#1048)
MasterJH5574 Oct 14, 2023
c2b8cbc
Fix Stable LM 3B build (#1061)
jeethu Oct 14, 2023
481cd92
[Core] Remove duplication in MODEL.get_model calls (#1054)
Lunderberg Oct 14, 2023
8184431
[ParamManager] Cleanup creation of quantization IRModule (#1053)
Lunderberg Oct 14, 2023
9010d48
Minor typo fix (#1064)
jeethu Oct 15, 2023
b0bfc88
Add links to Python API Reference (#1068)
junrushao Oct 15, 2023
204860b
[Fix] ChatModule incorrect temperature buffer shape (#1070)
MasterJH5574 Oct 15, 2023
d202077
[ParamManager] Added progress bar for get_item/set_item (#1063)
Lunderberg Oct 16, 2023
9872c48
[Python] Extract common device str parse function in ChatModule (#1074)
MasterJH5574 Oct 16, 2023
3aefd9f
[Bugfix] Compilation Error in q4f32_1 (#1078)
junrushao Oct 17, 2023
2625945
Establish `mlc_chat.compiler` (#1082)
junrushao Oct 19, 2023
56a8004
Update README.md for Multi-GPU (#1090)
junrushao Oct 19, 2023
b0373d1
Support lib_path override in C++. Improvements on docs and error mess…
rickzx Oct 19, 2023
830656f
StreamIterator (#1057)
varshith15 Oct 19, 2023
9bf5723
Update `benchmark.py` according to #1086 (#1091)
junrushao Oct 19, 2023
62d0c03
Disable Disco for q4f16_ft and q8f16_ft quantization (#1094)
LeshengJin Oct 20, 2023
cf39bf6
[Format] Apply isort and black for `python/` (#1097)
junrushao Oct 20, 2023
e9b85ce
More formatting (#1099)
junrushao Oct 21, 2023
03c641a
Enable Python Linter (#1098)
junrushao Oct 21, 2023
46d11e6
Add Basic Pylint and Mypy Tooling (#1100)
junrushao Oct 21, 2023
6159cc4
[CI] Add clang-format (#1103)
junrushao Oct 22, 2023
16dd2ae
[Slim-LM] Smart path finding for config and weight (#1088)
LeshengJin Oct 23, 2023
f57c9c9
[Transform] Provide IRModule transform for rewrite_attention (#1052)
Lunderberg Oct 23, 2023
e5927ce
[ParamManager] Use BundleModelParams for transform_dequantize (#1056)
Lunderberg Oct 23, 2023
7ae8c6d
[Slim-LM] Introduce HFLoad for loading Pytorch and SafeTensor weights…
LeshengJin Oct 23, 2023
5a7dcd8
[WINDOWS] reduce noise in windows build (#1115)
tqchen Oct 24, 2023
61179a0
Add CLI commands for compilation (#1109)
junrushao Oct 24, 2023
8ce7793
Auto updated submodule references
Oct 24, 2023
488017d
fix mismatched argument name (#1117)
Sing-Li Oct 24, 2023
206103b
[Docs] Add doc for max and mean gen len, shift factor; and buildArgs …
CharlieFRuan Oct 24, 2023
2aa6809
Revert "[ParamManager] Use BundleModelParams for transform_dequantize…
junrushao Oct 24, 2023
9cb8e8e
Remove inaccurate warning message (#1121)
junrushao Oct 24, 2023
9166edb
[REST] OpenAI compatible Rest API (#1107)
Kartik14 Oct 24, 2023
a4279e3
Add --opt flag parsing to CLI (#1123)
junrushao Oct 25, 2023
973f9fc
[ParamManager][Redo] Use BundleModelParams for transform_dequantize (…
Lunderberg Oct 25, 2023
24f795e
added details to windows installation (#1133)
goutham2688 Oct 27, 2023
2c492e5
Grammatical and Typographical improvements (#1139)
tmsagarofficial Oct 28, 2023
2ec0cc8
Minor enhancements to `ChatModule` (#1132)
YuchenJin Oct 28, 2023
27ac5ac
Updating tvm install docs (#1143)
David-Sharma Oct 29, 2023
2b6d832
Make the help info consistent with program name (#1137)
fennecJ Oct 29, 2023
878ae84
Support parameter packing (#1146)
junrushao Oct 29, 2023
c0c3a8d
[Slim-LM] Enable Group Quant (#1129)
zxybazh Oct 29, 2023
2193767
Enable Mypy and Pylint in mlc_chat Python Package (#1149)
junrushao Oct 29, 2023
0a25374
Migrate Compiler Passes (#1150)
junrushao Oct 30, 2023
1a79a53
Compile Model Preset without External `config.json` (#1151)
junrushao Oct 30, 2023
ba67835
Update attention layer (#1153)
junrushao Oct 30, 2023
fee2cb5
Add batched Llama model definition using vLLM paged attention (#1134)
masahi Oct 30, 2023
ece97b1
[Transform][Redo] Apply split_rotary optimization on prefill (#1125)
Lunderberg Oct 30, 2023
b190578
Apply rewrite for normal attention and MQA (#1138)
Lunderberg Oct 30, 2023
8ca0176
[Rest] Fix emoji handling in Rest API. (#1142)
YuchenJin Oct 30, 2023
3cf5605
[Utility] Check for isinstance(exc, Exception) before entering pdb (#…
Lunderberg Oct 30, 2023
0a9d6c7
[Utils] Remove conversion to numpy array in utils.save_params (#1083)
Lunderberg Oct 30, 2023
425a2cb
[Fix][REST] Use lowered-cased "app" (#1159)
junrushao Oct 30, 2023
ed10571
sync llama def
masahi Oct 30, 2023
bab2e50
Merge remote-tracking branch 'mlc-ai/main' into upstream-merge-oct31
masahi Oct 30, 2023
2252c66
fi
masahi Oct 30, 2023
[REST] OpenAI compatible Rest API (mlc-ai#1107)
* add presence and frequency penalty

* Added support for passing conversation history in /v1/chat/completions endpoint

* Added support for REST API parameters max_gen_len, n, and stop_str

* add presence and frequency penalty to generation config

* refactor generation config

* Added documentation for parameters

* replace lib_path with model_lib_path in rest.py

* fixed black isort issues

* fix lib_path
Kartik14 authored Oct 24, 2023
commit 9166edbf844bf039314e4453ff9f441c4738c8a6
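
For reference, the presence and frequency penalties added in this commit follow the OpenAI-style formulation: each token that has already appeared has its logit reduced by the presence penalty plus its appearance count multiplied by the frequency penalty, which is what ApplyPresenceAndFrequencyPenaltyOnCPU in the diff below does. A minimal standalone sketch of that adjustment (illustrative names only, not the repository code):

#include <cstdint>
#include <unordered_map>
#include <vector>

// Sketch only: penalize tokens that have already been generated.
// `token_freq` maps a token id to the number of times it has appeared so far.
void ApplyPresenceAndFrequencyPenalty(std::vector<float>& logits,
                                      const std::unordered_map<int32_t, int64_t>& token_freq,
                                      float presence_penalty, float frequency_penalty) {
  for (const auto& [token_id, count] : token_freq) {
    // A flat cost for having appeared at all, plus a cost that grows
    // with the number of appearances.
    logits[token_id] -= count * frequency_penalty + presence_penalty;
  }
}

Both penalties are validated to lie within [-2, 2], matching the CHECKs added in ReadGenerationConfig below.
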
232 changes: 151 additions & 81 deletions cpp/llm_chat.cc
@@ -650,7 +650,7 @@ class LLMChat {
this->ResetRuntimeStats();
}
output_ids_.clear();
appeared_token_ids_.clear();
appeared_token_freq_.clear();
output_message_.clear();
stop_triggered_ = false;
if (append_conversation) {
@@ -672,12 +672,8 @@
PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll,
String generation_config_str = "") {
// process generation settings
picojson::object generation_config = picojson::object();
if (!generation_config_str.empty()) {
picojson::value generation_config_json;
picojson::parse(generation_config_json, generation_config_str);
generation_config = generation_config_json.get<picojson::object>();
}
picojson::object generation_config =
this->LoadGenerationConfigFromString(generation_config_str);

std::vector<int32_t> prompt_tokens =
PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt, generation_config);
@@ -729,21 +725,16 @@
return;
}

// process generation settings
picojson::object generation_config = picojson::object();
if (!generation_config_str.empty()) {
picojson::value generation_config_json;
picojson::parse(generation_config_json, generation_config_str);
generation_config = generation_config_json.get<picojson::object>();
}
picojson::object generation_config =
this->LoadGenerationConfigFromString(generation_config_str);

int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config);

auto tend = std::chrono::high_resolution_clock::now();

this->prefill_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->prefill_total_tokens += token_len;
this->ProcessNextToken(next_token);
this->ProcessNextToken(next_token, generation_config);
}

/*!
@@ -768,13 +759,8 @@
return;
}

// process generation settings
picojson::object generation_config = picojson::object();
if (!generation_config_str.empty()) {
picojson::value generation_config_json;
picojson::parse(generation_config_json, generation_config_str);
generation_config = generation_config_json.get<picojson::object>();
}
picojson::object generation_config =
this->LoadGenerationConfigFromString(generation_config_str);

std::vector<int32_t> prompt_tokens =
this->PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt, generation_config);
@@ -803,17 +789,12 @@

this->prefill_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->prefill_total_tokens += token_len;
this->ProcessNextToken(next_token);
this->ProcessNextToken(next_token, generation_config);
}

void DecodeStep(String generation_config_str = "") {
// process generation settings
picojson::object generation_config = picojson::object();
if (!generation_config_str.empty()) {
picojson::value generation_config_json;
picojson::parse(generation_config_json, generation_config_str);
generation_config = generation_config_json.get<picojson::object>();
}
picojson::object generation_config =
this->LoadGenerationConfigFromString(generation_config_str);

ICHECK(!output_ids_.empty());
int32_t last_token = output_ids_.back();
@@ -830,7 +811,7 @@

this->decode_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->decode_total_tokens += 1;
this->ProcessNextToken(next_token);
this->ProcessNextToken(next_token, generation_config);
}

bool Stopped() { return stop_triggered_; }
@@ -931,61 +912,109 @@ class LLMChat {
picojson::object config;
config["temperature"] = picojson::value(this->temperature_);
config["repetition_penalty"] = picojson::value(this->repetition_penalty_);
config["presence_penalty"] = picojson::value(this->presence_penalty_);
config["frequency_penalty"] = picojson::value(this->frequency_penalty_);
config["top_p"] = picojson::value(this->top_p_);
config["mean_gen_len"] = picojson::value(this->mean_gen_len_);
config["max_gen_len"] = picojson::value(this->max_gen_len_);
config["shift_fill_factor"] = picojson::value(this->shift_fill_factor_);
config["conv_config"] = this->conversation_.SerializeToJSON();
return picojson::value(config);
}
/*!
* \brief Sample output token from logits on device
*/
int32_t SampleTokenFromLogits(NDArray logits_on_device,
picojson::object generation_config = picojson::object()) {
// prepare generation settings
// the generation_config will not override the original config
// since it is only used for this generation
double gen_temperature;
double gen_repetition_penalty;
double gen_top_p;

picojson::object LoadGenerationConfigFromString(const std::string& generation_config_str) {
picojson::object generation_config = picojson::object();
if (!generation_config_str.empty()) {
picojson::value generation_config_json;
picojson::parse(generation_config_json, generation_config_str);
generation_config = generation_config_json.get<picojson::object>();
}
return generation_config;
}

void ReadGenerationConfig(picojson::object generation_config, double* gen_temperature,
NDArray* gen_temperature_arr, double* gen_repetition_penalty,
double* gen_presence_penalty, double* gen_frequency_penalty,
double* gen_top_p) {
if (generation_config.count("temperature")) {
CHECK(generation_config["temperature"].is<double>());
gen_temperature = generation_config["temperature"].get<double>();
if (gen_temperature != this->temperature_) {
this->temperature_ = gen_temperature;
float temperature_cast = static_cast<float>(gen_temperature);
this->temperature_arr_.CopyFromBytes(&temperature_cast, sizeof(float));
}
*gen_temperature = generation_config["temperature"].get<double>();

*gen_temperature_arr = NDArray::Empty({}, DataType::Float(32), device_);
float temperature_cast = static_cast<float>(*gen_temperature);
gen_temperature_arr->CopyFromBytes(&temperature_cast, sizeof(float));
} else {
gen_temperature = this->temperature_;
*gen_temperature = this->temperature_;
*gen_temperature_arr = this->temperature_arr_;
}
if (generation_config.count("repetition_penalty")) {
CHECK(generation_config["repetition_penalty"].is<double>());
gen_repetition_penalty = generation_config["repetition_penalty"].get<double>();
CHECK(generation_config["repetition_penalty"].get<double>() > 0)
<< "Repetition penalty must be a positive number!";
*gen_repetition_penalty = generation_config["repetition_penalty"].get<double>();
} else {
*gen_repetition_penalty = this->repetition_penalty_;
}
if (generation_config.count("presence_penalty")) {
CHECK(generation_config["presence_penalty"].is<double>());
CHECK(abs(generation_config["presence_penalty"].get<double>()) <= 2)
<< "Presence penalty must be in the range -2 to 2!";
*gen_presence_penalty = generation_config["presence_penalty"].get<double>();
} else {
*gen_presence_penalty = this->presence_penalty_;
}
if (generation_config.count("frequency_penalty")) {
CHECK(generation_config["frequency_penalty"].is<double>());
CHECK(abs(generation_config["frequency_penalty"].get<double>()) <= 2)
<< "Frequency penalty must be in the range -2 to 2!";
*gen_frequency_penalty = generation_config["frequency_penalty"].get<double>();
} else {
gen_repetition_penalty = this->repetition_penalty_;
*gen_frequency_penalty = this->frequency_penalty_;
}
if (generation_config.count("top_p")) {
CHECK(generation_config["top_p"].is<double>());
gen_top_p = generation_config["top_p"].get<double>();
*gen_top_p = generation_config["top_p"].get<double>();
} else {
gen_top_p = this->top_p_;
*gen_top_p = this->top_p_;
}
}

/*!
* \brief Sample output token from logits on device
*/
int32_t SampleTokenFromLogits(NDArray logits_on_device,
picojson::object generation_config = picojson::object()) {
// prepare generation settings
// the generation_config will not override the original config
// since it is only used for this generation
double gen_temperature;
double gen_repetition_penalty;
double gen_presence_penalty;
double gen_frequency_penalty;
double gen_top_p;
this->ReadGenerationConfig(generation_config, &gen_temperature, &this->temperature_arr_,
&gen_repetition_penalty, &gen_presence_penalty,
&gen_frequency_penalty, &gen_top_p);

// update logits
if (gen_repetition_penalty == 1.0f) {
if (gen_temperature < 1e-6f) {
this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
} else {
this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_));
if (gen_presence_penalty != 0.0f || gen_frequency_penalty != 0.0f) {
this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
this->ApplyPresenceAndFrequencyPenaltyOnCPU(gen_presence_penalty, gen_frequency_penalty);
if (gen_temperature >= 1e-6f) {
this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature);
}
} else {
} else if (gen_repetition_penalty != 1.0f) {
this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
this->ApplyRepetitionPenaltyOnCPU(gen_repetition_penalty);
if (gen_temperature >= 1e-6f) {
this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature);
}
} else {
if (gen_temperature < 1e-6f) {
this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
} else {
this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_));
}
}

// perform sampling
Expand Down Expand Up @@ -1018,6 +1047,25 @@ class LLMChat {
gen_max_gen_len = this->max_gen_len_;
}

std::vector<std::string> gen_stop_strs;
gen_stop_strs.push_back(conversation_.stop_str);

if (generation_config.count("stop")) {
if (!generation_config["stop"].is<picojson::null>()) {
CHECK(generation_config["stop"].is<std::string>() ||
generation_config["stop"].is<picojson::array>());
if (generation_config["stop"].is<std::string>()) {
gen_stop_strs.push_back(generation_config["stop"].get<std::string>());
} else {
picojson::array gen_stop_strs_arr = generation_config["stop"].get<picojson::array>();
for (const picojson::value& v : gen_stop_strs_arr) {
CHECK(v.is<std::string>());
gen_stop_strs.push_back(v.get<std::string>());
}
}
}
}

ICHECK(!stop_triggered_) << "Cannot call process when it is stopped";

stop_triggered_ =
@@ -1026,27 +1074,35 @@

if (!stop_triggered_) {
output_ids_.push_back(next_token);
appeared_token_ids_.insert(next_token);
if (appeared_token_freq_.find(next_token) != appeared_token_freq_.end()) {
appeared_token_freq_[next_token] += 1;
} else {
appeared_token_freq_[next_token] = 1;
}
}

output_message_ = tokenizer_->Decode(output_ids_);

if (!conversation_.stop_str.empty()) {
size_t stop_pos = output_message_.rfind(conversation_.stop_str);
if (stop_pos != std::string::npos) {
stop_triggered_ = true;
if (ft_.support_backtracking_kv_) {
// back-track: pop tokens until the decoded output no longer
// extends past the stop string position
size_t backoff = 0;
for (; (output_ids_.size() > 0) && (output_message_.length() > stop_pos); ++backoff) {
output_ids_.pop_back();
output_message_ = tokenizer_->Decode(output_ids_);
}
// resize kv to remove the context
ft_.fkvcache_array_popn_(kv_cache_, backoff);
total_seq_len_ -= backoff;
size_t stop_pos = std::string::npos;
for (const std::string& stop_str : gen_stop_strs) {
if (!stop_str.empty()) {
stop_pos = std::min(stop_pos, output_message_.rfind(stop_str));
}
}

if (stop_pos != std::string::npos) {
stop_triggered_ = true;
if (ft_.support_backtracking_kv_) {
// back-track: pop tokens until the decoded output no longer
// extends past the stop string position
size_t backoff = 0;
for (; (output_ids_.size() > 0) && (output_message_.length() > stop_pos); ++backoff) {
output_ids_.pop_back();
output_message_ = tokenizer_->Decode(output_ids_);
}
// resize kv to remove the context
ft_.fkvcache_array_popn_(kv_cache_, backoff);
total_seq_len_ -= backoff;
}
}

@@ -1113,15 +1169,25 @@
CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!";
CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
float* logits_raw_data = static_cast<float*>(logits_on_cpu_->data);
for (const int32_t& token_id : this->appeared_token_ids_) {
if (logits_raw_data[token_id] <= 0) {
logits_raw_data[token_id] *= repetition_penalty;
for (const auto& token_freq : this->appeared_token_freq_) {
if (logits_raw_data[token_freq.first] <= 0) {
logits_raw_data[token_freq.first] *= repetition_penalty;
} else { // logits > 0
logits_raw_data[token_id] /= repetition_penalty;
logits_raw_data[token_freq.first] /= repetition_penalty;
}
}
}

void ApplyPresenceAndFrequencyPenaltyOnCPU(float presence_penalty, float frequency_penalty) {
CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!";
CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
float* logits_raw_data = static_cast<float*>(logits_on_cpu_->data);
for (const auto& token_freq : this->appeared_token_freq_) {
logits_raw_data[token_freq.first] -=
(token_freq.second * frequency_penalty + presence_penalty);
}
}

void ApplySoftmaxWithTemperatureOnCPU(float temperature) {
CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!";
CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
Expand Down Expand Up @@ -1211,12 +1277,16 @@ class LLMChat {
NDArray temperature_arr_;
// repetition penalty
double repetition_penalty_{1.0};
// presence penalty
double presence_penalty_{0.0};
// frequency penalty
double frequency_penalty_{0.0};
// top_p
double top_p_{0.95};
// output ids till now (refresh after encoding step)
std::vector<int32_t> output_ids_;
// appeared token ids till now (refresh after encoding step)
std::unordered_set<int32_t> appeared_token_ids_;
// frequency of appeared token ids till now (refresh after encoding step)
std::unordered_map<int32_t, int64_t> appeared_token_freq_;
// output message till now (refresh after encoding step)
std::string output_message_;
// Whether encounter stop str