Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature importance type in saved model file #3220

Merged
merged 11 commits into from
Jul 15, 2020
9 changes: 6 additions & 3 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ Booster <- R6::R6Class(
},

# Save model
save_model = function(filename, num_iteration = NULL) {
save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

# Check if number of iteration is non existent
if (is.null(num_iteration)) {
Expand All @@ -437,6 +437,7 @@ Booster <- R6::R6Class(
, ret = NULL
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, lgb.c_str(filename)
)

Expand All @@ -445,7 +446,7 @@ Booster <- R6::R6Class(
},

# Save model to string
save_model_to_string = function(num_iteration = NULL) {
save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {

# Check if number of iteration is non existent
if (is.null(num_iteration)) {
Expand All @@ -457,12 +458,13 @@ Booster <- R6::R6Class(
"LGBM_BoosterSaveModelToString_R"
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
))

},

# Dump model in memory
dump_model = function(num_iteration = NULL) {
dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

# Check if number of iteration is non existent
if (is.null(num_iteration)) {
Expand All @@ -474,6 +476,7 @@ Booster <- R6::R6Class(
"LGBM_BoosterDumpModel_R"
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
)

},
Expand Down
15 changes: 9 additions & 6 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,37 +632,40 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,

LGBM_SE LGBM_BoosterSaveModel_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE filename,
LGBM_SE call_state) {
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_CHAR_PTR(filename)));
CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(feature_importance_type), R_CHAR_PTR(filename)));
R_API_END();
}

LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str,
LGBM_SE call_state) {
R_API_BEGIN();
int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(feature_importance_type), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
R_API_END();
}

LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str,
LGBM_SE call_state) {
R_API_BEGIN();
int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(feature_importance_type), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
R_API_END();
}
Expand Down Expand Up @@ -707,9 +710,9 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterCalcNumPredict_R" , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R , 8},
{"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 14},
{"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 11},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 6},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 6},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 5},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 7},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 7},
{NULL, NULL, 0}
};

Expand Down
3 changes: 3 additions & 0 deletions R-package/src/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R(
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModel_R(
LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE filename,
LGBM_SE call_state
);
Expand All @@ -604,6 +605,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModel_R(
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModelToString_R(
LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str,
Expand All @@ -620,6 +622,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModelToString_R(
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterDumpModel_R(
LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str,
Expand Down
6 changes: 6 additions & 0 deletions docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,12 @@ Learning Control Parameters

- **Note**: can be used only in CLI version

- ``saved_feature_importance_type`` :raw-html:`<a id="saved_feature_importance_type" title="Permalink to this parameter" href="#saved_feature_importance_type">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int

- the feature importance type in the saved model file

- ``0``: count-based feature importance; ``1``: gain-based feature importance

- ``snapshot_freq`` :raw-html:`<a id="snapshot_freq" title="Permalink to this parameter" href="#snapshot_freq">&#x1F517;&#xFE0E;</a>`, default = ``-1``, type = int, aliases: ``save_period``

- frequency of saving model file snapshot
Expand Down
10 changes: 6 additions & 4 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,10 @@ class LIGHTGBM_EXPORT Boosting {
* \brief Dump model to json format string
* \param start_iteration Start index of the iteration that should be dumped
* \param num_iteration Number of iterations to dump, -1 means dump all
* \param feature_importance_type type of feature importance, 0: count, 1: gain
* \return Json format string of model
*/
virtual std::string DumpModel(int start_iteration, int num_iteration) const = 0;
virtual std::string DumpModel(int start_iteration, int num_iteration, int feature_importance_type) const = 0;

/*!
* \brief Translate model to if-else statement
Expand All @@ -199,19 +200,20 @@ class LIGHTGBM_EXPORT Boosting {
* \brief Save model to file
* \param start_iteration Start index of the iteration that should be saved
* \param num_iterations Number of iterations to save, -1 means save all
* \param is_finish Is training finished or not
* \param feature_importance_type type of feature importance, 0: count, 1:gain
* \param filename Name of the file to save the model to
* \return true if succeeded
*/
virtual bool SaveModelToFile(int start_iteration, int num_iterations, const char* filename) const = 0;
virtual bool SaveModelToFile(int start_iteration, int num_iterations, int feature_importance_type, const char* filename) const = 0;

/*!
* \brief Save model to string
* \param start_iteration Start index of the iteration that should be saved
* \param num_iterations Number of iterations to save, -1 means save all
* \param feature_importance_type type of feature importance, 0: count, 1:gain
* \return Non-empty string if succeeded
*/
virtual std::string SaveModelToString(int start_iteration, int num_iterations) const = 0;
virtual std::string SaveModelToString(int start_iteration, int num_iterations, int feature_importance_type) const = 0;

/*!
* \brief Restore from a serialized string
Expand Down
6 changes: 6 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -996,19 +996,22 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMats(BoosterHandle handle,
* \param handle Handle of booster
* \param start_iteration Start index of the iteration that should be saved
* \param num_iteration Index of the iteration that should be saved, <= 0 means save all
* \param feature_importance_type type of feature importance, 0: count, 1:gain
* \param filename The name of the file
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
int start_iteration,
int num_iteration,
int feature_importance_type,
const char* filename);

/*!
* \brief Save model to string.
* \param handle Handle of booster
* \param start_iteration Start index of the iteration that should be saved
* \param num_iteration Index of the iteration that should be saved, <= 0 means save all
* \param feature_importance_type type of feature importance, 0: count, 1:gain
* \param buffer_len String buffer length, if ``buffer_len < out_len``, you should re-allocate buffer
* \param[out] out_len Actual output length
* \param[out] out_str String of model, should pre-allocate memory
Expand All @@ -1017,6 +1020,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int start_iteration,
int num_iteration,
int feature_importance_type,
int64_t buffer_len,
int64_t* out_len,
char* out_str);
Expand All @@ -1026,6 +1030,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
* \param handle Handle of booster
* \param start_iteration Start index of the iteration that should be dumped
* \param num_iteration Index of the iteration that should be dumped, <= 0 means dump all
* \param feature_importance_type type of feature importance, 0: count, 1:gain
* \param buffer_len String buffer length, if ``buffer_len < out_len``, you should re-allocate buffer
* \param[out] out_len Actual output length
* \param[out] out_str JSON format string of model, should pre-allocate memory
Expand All @@ -1034,6 +1039,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
int start_iteration,
int num_iteration,
int feature_importance_type,
int64_t buffer_len,
int64_t* out_len,
char* out_str);
Expand Down
4 changes: 4 additions & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,10 @@ struct Config {
// desc = **Note**: can be used only in CLI version
std::string output_model = "LightGBM_model.txt";

// desc = the feature importance type in the saved model file
// desc = ``0``: count-based feature importance; ``1``: gain-based feature importance
int saved_feature_importance_type = 0;

// [no-save]
// alias = save_period
// desc = frequency of saving model file snapshot
Expand Down
17 changes: 14 additions & 3 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2600,7 +2600,7 @@ def eval_valid(self, feval=None):
return [item for i in range_(1, self.__num_dataset)
for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)]

def save_model(self, filename, num_iteration=None, start_iteration=0):
def save_model(self, filename, num_iteration=None, start_iteration=0, feature_importance_type=0):
"""Save Booster to file.

Parameters
Expand All @@ -2613,6 +2613,8 @@ def save_model(self, filename, num_iteration=None, start_iteration=0):
If <= 0, all iterations are saved.
start_iteration : int, optional (default=0)
Start index of the iteration that should be saved.
feature_importance_type : int, optional (default=0)
What type of feature importance should be saved. 0: count-based; 1: gain-based.

Returns
-------
Expand All @@ -2625,6 +2627,7 @@ def save_model(self, filename, num_iteration=None, start_iteration=0):
self.handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
ctypes.c_int(feature_importance_type),
c_str(filename)))
_dump_pandas_categorical(self.pandas_categorical, filename)
return self
Expand Down Expand Up @@ -2685,7 +2688,7 @@ def model_from_string(self, model_str, verbose=True):
self.pandas_categorical = _load_pandas_categorical(model_str=model_str)
return self

def model_to_string(self, num_iteration=None, start_iteration=0):
def model_to_string(self, num_iteration=None, start_iteration=0, feature_importance_type=0):
"""Save Booster to string.

Parameters
Expand All @@ -2696,6 +2699,8 @@ def model_to_string(self, num_iteration=None, start_iteration=0):
If <= 0, all iterations are saved.
start_iteration : int, optional (default=0)
Start index of the iteration that should be saved.
feature_importance_type : int, optional (default=0)
What type of feature importance should be saved. 0: count-based; 1: gain-based.

Returns
-------
Expand All @@ -2712,6 +2717,7 @@ def model_to_string(self, num_iteration=None, start_iteration=0):
self.handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
ctypes.c_int(feature_importance_type),
ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
Expand All @@ -2724,14 +2730,15 @@ def model_to_string(self, num_iteration=None, start_iteration=0):
self.handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
ctypes.c_int(feature_importance_type),
ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
ret = string_buffer.value.decode('utf-8')
ret += _dump_pandas_categorical(self.pandas_categorical)
return ret

def dump_model(self, num_iteration=None, start_iteration=0):
def dump_model(self, num_iteration=None, start_iteration=0, feature_importance_type=0):
"""Dump Booster to JSON format.

Parameters
Expand All @@ -2742,6 +2749,8 @@ def dump_model(self, num_iteration=None, start_iteration=0):
If <= 0, all iterations are dumped.
start_iteration : int, optional (default=0)
Start index of the iteration that should be dumped.
feature_importance_type : int, optional (default=0)
What type of feature importance should be dumped. 0: count-based; 1: gain-based.

Returns
-------
Expand All @@ -2758,6 +2767,7 @@ def dump_model(self, num_iteration=None, start_iteration=0):
self.handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
ctypes.c_int(feature_importance_type),
ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
Expand All @@ -2770,6 +2780,7 @@ def dump_model(self, num_iteration=None, start_iteration=0):
self.handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
ctypes.c_int(feature_importance_type),
ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
Expand Down
6 changes: 4 additions & 2 deletions src/application/application.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ void Application::InitTrain() {
void Application::Train() {
Log::Info("Started training...");
boosting_->Train(config_.snapshot_freq, config_.output_model);
boosting_->SaveModelToFile(0, -1, config_.output_model.c_str());
boosting_->SaveModelToFile(0, -1, config_.saved_feature_importance_type,
config_.output_model.c_str());
// convert model to if-else statement code
if (config_.convert_model_language == std::string("cpp")) {
boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str());
Expand Down Expand Up @@ -233,7 +234,8 @@ void Application::Predict() {
boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
boosting_->RefitTree(pred_leaf);
boosting_->SaveModelToFile(0, -1, config_.output_model.c_str());
boosting_->SaveModelToFile(0, -1, config_.saved_feature_importance_type,
config_.output_model.c_str());
Log::Info("Finished RefitTree");
} else {
// create predictor
Expand Down
2 changes: 1 addition & 1 deletion src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
if (snapshot_freq > 0
&& (iter + 1) % snapshot_freq == 0) {
std::string snapshot_out = model_output_path + ".snapshot_iter_" + std::to_string(iter + 1);
SaveModelToFile(0, -1, snapshot_out.c_str());
SaveModelToFile(0, -1, config_->saved_feature_importance_type, snapshot_out.c_str());
}
}
}
Expand Down
12 changes: 9 additions & 3 deletions src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,11 @@ class GBDT : public GBDTBase {
* \brief Dump model to json format string
* \param start_iteration Start index of the iteration that should be dumped
* \param num_iteration Number of iterations to dump, -1 means dump all
* \param feature_importance_type type of feature importance, 0: count, 1:gain
* \return Json format string of model
*/
std::string DumpModel(int start_iteration, int num_iteration) const override;
std::string DumpModel(int start_iteration, int num_iteration,
int feature_importance_type) const override;

/*!
* \brief Translate model to if-else statement
Expand All @@ -272,18 +274,22 @@ class GBDT : public GBDTBase {
* \brief Save model to file
* \param start_iteration Start index of the iteration that should be saved
* \param num_iterations Number of iterations to save, -1 means save all
* \param feature_importance_type type of feature importance, 0: count, 1:gain
* \param filename Name of the file to save the model to
* \return true if succeeded
*/
bool SaveModelToFile(int start_iteration, int num_iterations, const char* filename) const override;
bool SaveModelToFile(int start_iteration, int num_iterations,
int feature_importance_type,
const char* filename) const override;

/*!
* \brief Save model to string
* \param start_iteration Start index of the iteration that should be saved
* \param num_iterations Number of iterations to save, -1 means save all
* \param feature_importance_type type of feature importance, 0: count, 1:gain
* \return Non-empty string if succeeded
*/
std::string SaveModelToString(int start_iteration, int num_iterations) const override;
std::string SaveModelToString(int start_iteration, int num_iterations, int feature_importance_type) const override;

/*!
* \brief Restore from a serialized buffer
Expand Down
Loading