Skip to content

Commit

Permalink
[python-package] pass params of engine.train and engine.cv to Dataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Dec 28, 2016
1 parent 292f972 commit 616388e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
6 changes: 6 additions & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,12 @@ def create_valid(self, data, label=None, weight=None, group=None,
ret._set_predictor(self._predictor)
return ret

def _update_params(self, params):
if not self.params:
self.params = params
else:
self.params.update(params)

def construct(self):
"""
Lazy init
Expand Down
6 changes: 4 additions & 2 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def train(params, train_set, num_boost_round=100,
if not isinstance(train_set, Dataset):
raise TypeError("Traninig only accepts Dataset object")

train_set._update_params(params)
train_set._set_predictor(predictor)
train_set.set_feature_name(feature_name)
train_set.set_categorical_feature(categorical_feature)
Expand Down Expand Up @@ -120,7 +121,8 @@ def train(params, train_set, num_boost_round=100,
name_valid_sets.append(valid_names[i])
else:
name_valid_sets.append('valid_'+str(i))

for valid_data in valid_sets:
valid_data._update_params(params)
"""process callbacks"""
if callbacks is None:
callbacks = set()
Expand Down Expand Up @@ -332,7 +334,7 @@ def cv(params, train_set, num_boost_round=10, nfold=5, stratified=False,
predictor = init_model._to_predictor()
else:
predictor = None

train_set._update_params(params)
train_set._set_predictor(predictor)
train_set.set_feature_name(feature_name)
train_set.set_categorical_feature(categorical_feature)
Expand Down
12 changes: 6 additions & 6 deletions src/io/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,12 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
CHECK(min_sum_hessian_in_leaf > 1.0f || min_data_in_leaf > 0);
GetDouble(params, "lambda_l1", &lambda_l1);
CHECK(lambda_l1 >= 0.0f)
GetDouble(params, "lambda_l2", &lambda_l2);
CHECK(lambda_l2 >= 0.0f)
GetDouble(params, "min_gain_to_split", &min_gain_to_split);
CHECK(min_gain_to_split >= 0.0f)
GetInt(params, "num_leaves", &num_leaves);
CHECK(lambda_l1 >= 0.0f);
GetDouble(params, "lambda_l2", &lambda_l2);
CHECK(lambda_l2 >= 0.0f);
GetDouble(params, "min_gain_to_split", &min_gain_to_split);
CHECK(min_gain_to_split >= 0.0f);
GetInt(params, "num_leaves", &num_leaves);
CHECK(num_leaves > 1);
GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
GetDouble(params, "feature_fraction", &feature_fraction);
Expand Down

0 comments on commit 616388e

Please sign in to comment.