diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 3a11f0216bc7..33ddaa67da8a 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -927,6 +927,12 @@ def create_valid(self, data, label=None, weight=None, group=None, ret._set_predictor(self._predictor) return ret + def _update_params(self, params): + if not self.params: + self.params = params + else: + self.params.update(params) + def construct(self): """ Lazy init diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 102eadbe9143..2d1bc79a001e 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -92,6 +92,7 @@ def train(params, train_set, num_boost_round=100, if not isinstance(train_set, Dataset): raise TypeError("Traninig only accepts Dataset object") + train_set._update_params(params) train_set._set_predictor(predictor) train_set.set_feature_name(feature_name) train_set.set_categorical_feature(categorical_feature) @@ -120,7 +121,8 @@ def train(params, train_set, num_boost_round=100, name_valid_sets.append(valid_names[i]) else: name_valid_sets.append('valid_'+str(i)) - + for valid_data in valid_sets: + valid_data._update_params(params) """process callbacks""" if callbacks is None: callbacks = set() @@ -332,7 +334,7 @@ def cv(params, train_set, num_boost_round=10, nfold=5, stratified=False, predictor = init_model._to_predictor() else: predictor = None - + train_set._update_params(params) train_set._set_predictor(predictor) train_set.set_feature_name(feature_name) train_set.set_categorical_feature(categorical_feature) diff --git a/src/io/config.cpp b/src/io/config.cpp index 7044b9ab2398..35d5a16eb365 100644 --- a/src/io/config.cpp +++ b/src/io/config.cpp @@ -273,12 +273,12 @@ void TreeConfig::Set(const std::unordered_map& params) GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf); CHECK(min_sum_hessian_in_leaf > 1.0f || min_data_in_leaf > 0); GetDouble(params, "lambda_l1", &lambda_l1); - CHECK(lambda_l1 >= 0.0f) - GetDouble(params, "lambda_l2", &lambda_l2); - CHECK(lambda_l2 >= 0.0f) - GetDouble(params, "min_gain_to_split", &min_gain_to_split); - CHECK(min_gain_to_split >= 0.0f) - GetInt(params, "num_leaves", &num_leaves); + CHECK(lambda_l1 >= 0.0f); + GetDouble(params, "lambda_l2", &lambda_l2); + CHECK(lambda_l2 >= 0.0f); + GetDouble(params, "min_gain_to_split", &min_gain_to_split); + CHECK(min_gain_to_split >= 0.0f); + GetInt(params, "num_leaves", &num_leaves); CHECK(num_leaves > 1); GetInt(params, "feature_fraction_seed", &feature_fraction_seed); GetDouble(params, "feature_fraction", &feature_fraction);