
Categorical feature support #108

Merged Dec 5, 2016 (26 commits)

Commits:
261145a  The logic for tree prediction (guolinke, Dec 1, 2016)
a707174  fix error for uint8_t to string (guolinke, Dec 1, 2016)
762c2aa  some bug fix (guolinke, Dec 1, 2016)
40ee200  update logic for costructing bin mapper (guolinke, Dec 1, 2016)
90a1fe2  add main logic for find best threshold (guolinke, Dec 1, 2016)
65b7739  fix bugs (guolinke, Dec 1, 2016)
7de3551  save/load feature_names to model file. expose c_api of set feature_names (guolinke, Dec 2, 2016)
790b53b  merge from master (guolinke, Dec 2, 2016)
0f5e9dc  use function pointer to avoid if..else (guolinke, Dec 2, 2016)
d997e54  update dataset_loader to support specific categorical feature (guolinke, Dec 2, 2016)
ce58581  some warnings fixed (guolinke, Dec 3, 2016)
87e3425  update format in dump_model (guolinke, Dec 3, 2016)
8ffa9f6  use std::function to avoid branching (guolinke, Dec 3, 2016)
f5795dd  fix json format (guolinke, Dec 3, 2016)
2e2a4b2  reduce memory cost for feature histogram (guolinke, Dec 3, 2016)
6e03b7d  fix json format (guolinke, Dec 3, 2016)
4b8b964  update join function (guolinke, Dec 3, 2016)
1bdfe24  support set categorical feature in python package (guolinke, Dec 3, 2016)
ae672b8  fix bug that using std::numeric_limits<int>::infinity(). (guolinke, Dec 3, 2016)
a6acade  update column parse logic (guolinke, Dec 3, 2016)
75bd396  some naming fix (guolinke, Dec 3, 2016)
1510033  fix one bug in continue train (guolinke, Dec 4, 2016)
6f044ac  bug fix in str2tree. (guolinke, Dec 4, 2016)
158ddd5  refine some template functions (guolinke, Dec 4, 2016)
0a7712b  comment refine (guolinke, Dec 5, 2016)
80c1fa8  move 'other_categorical' to last bin (guolinke, Dec 5, 2016)
Changes from 1 commit: fix bugs
guolinke committed Dec 1, 2016
commit 65b7739b06ff89709757fd156241d85f6d82b758
src/io/dense_bin.hpp (3 additions, 3 deletions)

@@ -172,16 +172,16 @@ template <typename VAL_T>
 class DenseCategoricalBin: public DenseBin<VAL_T> {
  public:
   DenseCategoricalBin(data_size_t num_data, int default_bin)
-    : DenseBin(num_data, default_bin) {
+    : DenseBin<VAL_T>(num_data, default_bin) {
   }

-  data_size_t Split(unsigned int threshold, data_size_t* data_indices, data_size_t num_data,
+  virtual data_size_t Split(unsigned int threshold, data_size_t* data_indices, data_size_t num_data,
     data_size_t* lte_indices, data_size_t* gt_indices) const override {
     data_size_t lte_count = 0;
     data_size_t gt_count = 0;
     for (data_size_t i = 0; i < num_data; ++i) {
       data_size_t idx = data_indices[i];
-      if (data_[idx] != threshold) {
+      if (DenseBin<VAL_T>::data_[idx] != threshold) {
         gt_indices[gt_count++] = idx;
       } else {
         lte_indices[lte_count++] = idx;
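Both edits in this file are the standard fix for C++ two-phase name lookup: inside a class template, unqualified names are not looked up in a dependent base class, so the base constructor must be spelled `DenseBin<VAL_T>(...)` and the inherited member qualified as `DenseBin<VAL_T>::data_` (or reached via `this->`). The same qualification is applied to `SparseBin<VAL_T>` in the next file. (The added `virtual` keyword is redundant next to `override`, but harmless.) A minimal standalone reproduction, with illustrative names rather than the LightGBM classes:

    #include <vector>

    // Minimal reproduction of the lookup issue fixed above; Base and
    // Derived are illustrative names, not the LightGBM types.
    template <typename T>
    class Base {
     public:
      explicit Base(int n) : data_(n) {}
     protected:
      std::vector<T> data_;
    };

    template <typename T>
    class Derived : public Base<T> {
     public:
      // ": Base(n)" would not compile: a dependent base's name is not
      // found by unqualified lookup, so the template arguments are needed.
      explicit Derived(int n) : Base<T>(n) {}

      T First() const {
        // Likewise, "data_" alone is not visible here; qualify it with
        // the base (as the commit does) or write "this->data_".
        return Base<T>::data_[0];
      }
    };

    int main() {
      Derived<int> d(4);
      return d.First();  // 0: the vector value-initializes its ints
    }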
src/io/sparse_bin.hpp (2 additions, 2 deletions)

@@ -303,10 +303,10 @@ template <typename VAL_T>
 class SparseCategoricalBin: public SparseBin<VAL_T> {
  public:
   SparseCategoricalBin(data_size_t num_data, int default_bin)
-    : SparseBin(num_data, default_bin) {
+    : SparseBin<VAL_T>(num_data, default_bin) {
   }

-  data_size_t Split(unsigned int threshold, data_size_t* data_indices, data_size_t num_data,
+  virtual data_size_t Split(unsigned int threshold, data_size_t* data_indices, data_size_t num_data,
     data_size_t* lte_indices, data_size_t* gt_indices) const override {
     // not need to split
     if (num_data <= 0) { return 0; }
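In both bin implementations, the categorical `Split()` shown above partitions rows by equality with a single bin value: a row goes to the "lte" side when its bin equals the chosen category and to the "gt" side otherwise, i.e. a one-vs-rest split at this point in the PR. A simplified standalone version of that partition logic (types and names flattened for illustration):

    #include <vector>

    using data_size_t = int;

    // Simplified one-vs-rest partition mirroring the categorical Split()
    // above: rows whose bin equals `threshold` go to the lte (left) side,
    // every other category goes to the gt (right) side.
    data_size_t SplitCategorical(const std::vector<unsigned char>& bins,
                                 unsigned int threshold,
                                 const data_size_t* data_indices,
                                 data_size_t num_data,
                                 data_size_t* lte_indices,
                                 data_size_t* gt_indices) {
      data_size_t lte_count = 0;
      data_size_t gt_count = 0;
      for (data_size_t i = 0; i < num_data; ++i) {
        const data_size_t idx = data_indices[i];
        if (bins[idx] != threshold) {
          gt_indices[gt_count++] = idx;   // any other category: right child
        } else {
          lte_indices[lte_count++] = idx; // matching category: left child
        }
      }
      return lte_count;  // number of rows routed to the left child
    }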
src/treelearner/feature_histogram.hpp (40 additions, 30 deletions)

@@ -159,19 +159,24 @@ class FeatureHistogram {
         best_gain = current_gain;
       }
     }
-    // update split information
-    output->feature = feature_idx_;
-    output->threshold = best_threshold;
-    output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian);
-    output->left_count = best_left_count;
-    output->left_sum_gradient = best_sum_left_gradient;
-    output->left_sum_hessian = best_sum_left_hessian;
-    output->right_output = CalculateSplittedLeafOutput(sum_gradients_ - best_sum_left_gradient,
-      sum_hessians_ - best_sum_left_hessian);
-    output->right_count = num_data_ - best_left_count;
-    output->right_sum_gradient = sum_gradients_ - best_sum_left_gradient;
-    output->right_sum_hessian = sum_hessians_ - best_sum_left_hessian;
-    output->gain = best_gain - gain_shift;
+    if (best_threshold < static_cast<unsigned int>(num_bins_)) {
+      // update split information
+      output->feature = feature_idx_;
+      output->threshold = best_threshold;
+      output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian);
+      output->left_count = best_left_count;
+      output->left_sum_gradient = best_sum_left_gradient;
+      output->left_sum_hessian = best_sum_left_hessian;
+      output->right_output = CalculateSplittedLeafOutput(sum_gradients_ - best_sum_left_gradient,
+        sum_hessians_ - best_sum_left_hessian);
+      output->right_count = num_data_ - best_left_count;
+      output->right_sum_gradient = sum_gradients_ - best_sum_left_gradient;
+      output->right_sum_hessian = sum_hessians_ - best_sum_left_hessian;
+      output->gain = best_gain - gain_shift;
+    } else {
+      output->feature = feature_idx_;
+      output->gain = kMinScore;
+    }
   }

 /*!
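The new guard matters because the check implies `best_threshold` starts at an out-of-range value and only moves into `[0, num_bins_)` when some bin passes every constraint. Without it, a feature with no feasible split would report garbage split information, and the categorical version below would even index `data_[best_threshold]` out of bounds. A minimal sketch of the pattern, with illustrative names (`SplitInfo` and `kMinScore` stand in for the real structures):

    #include <limits>
    #include <vector>

    // Minimal sketch of the "no feasible split" guard added above.
    const double kMinScore = -std::numeric_limits<double>::infinity();

    struct SplitInfo {
      unsigned int threshold = 0;
      double gain = kMinScore;
    };

    void FindBestThreshold(const std::vector<double>& bin_gain, SplitInfo* output) {
      const unsigned int num_bins = static_cast<unsigned int>(bin_gain.size());
      unsigned int best_threshold = num_bins;  // out-of-range sentinel
      double best_gain = kMinScore;
      for (unsigned int t = 0; t < num_bins; ++t) {
        // Stand-in for the min-data / min-hessian constraints: if every
        // bin fails, best_threshold keeps the sentinel for the whole scan.
        if (bin_gain[t] <= 0.0) continue;
        if (bin_gain[t] > best_gain) {
          best_gain = bin_gain[t];
          best_threshold = t;
        }
      }
      if (best_threshold < num_bins) {
        output->threshold = best_threshold;  // a real candidate was found
        output->gain = best_gain;
      } else {
        // No bin qualified: report kMinScore so this feature loses to any
        // real candidate instead of exposing stale or invalid values.
        output->gain = kMinScore;
      }
    }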
@@ -194,11 +199,11 @@
       if (current_count < min_num_data_one_leaf_ || sum_current_hessian < min_sum_hessian_one_leaf_) continue;
       data_size_t other_count = num_data_ - current_count;
       // if data not enough
-      if (other_count < min_num_data_one_leaf_) break;
+      if (other_count < min_num_data_one_leaf_) continue;

       double sum_other_hessian = sum_hessians_ - sum_current_hessian;
       // if sum hessian too small
-      if (sum_other_hessian < min_sum_hessian_one_leaf_) break;
+      if (sum_other_hessian < min_sum_hessian_one_leaf_) continue;

       double sum_other_gradient = sum_gradients_ - sum_current_gradient;
       // current split gain
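The `break` to `continue` change is the substantive fix in this hunk. The loop scans categorical bins, whose per-bin counts are independent, so a count or hessian constraint failing at one bin says nothing about the remaining bins; an early `break` is only sound for an ordered feature, where the left-side count grows monotonically with the threshold. A small illustration with made-up counts:

    #include <cstdio>

    int main() {
      const int num_data = 100;
      const int min_num_data_one_leaf = 20;
      // Made-up per-category counts; unlike an ordered feature's
      // cumulative counts, these are not monotone across bins.
      const int bin_count[] = {5, 90, 30, 60};
      for (int t = 0; t < 4; ++t) {
        const int other_count = num_data - bin_count[t];
        if (bin_count[t] < min_num_data_one_leaf) continue;  // bin 0 fails
        if (other_count < min_num_data_one_leaf) continue;   // bin 1 fails
        // A `break` on either check would have skipped bins 2 and 3,
        // both of which are valid one-vs-rest candidates.
        std::printf("bin %d: %d vs %d is a valid candidate\n",
                    t, bin_count[t], other_count);
      }
      return 0;
    }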
@@ -216,21 +221,26 @@
         }
       }
       // update split information
-      output->feature = feature_idx_;
-      output->threshold = best_threshold;
-      output->left_output = CalculateSplittedLeafOutput(data_[best_threshold].sum_gradients,
-        data_[best_threshold].sum_hessians);
-      output->left_count = data_[best_threshold].cnt;
-      output->left_sum_gradient = data_[best_threshold].sum_gradients;
-      output->left_sum_hessian = data_[best_threshold].sum_hessians;
-
-      output->right_output = CalculateSplittedLeafOutput(sum_gradients_ - data_[best_threshold].sum_gradients,
-        sum_hessians_ - data_[best_threshold].sum_hessians);
-      output->right_count = num_data_ - data_[best_threshold].cnt;
-      output->right_sum_gradient = sum_gradients_ - data_[best_threshold].sum_gradients;
-      output->right_sum_hessian = sum_hessians_ - data_[best_threshold].sum_hessians;
-
-      output->gain = best_gain - gain_shift;
+      if (best_threshold < static_cast<unsigned int>(num_bins_)) {
+        output->feature = feature_idx_;
+        output->threshold = best_threshold;
+        output->left_output = CalculateSplittedLeafOutput(data_[best_threshold].sum_gradients,
+          data_[best_threshold].sum_hessians);
+        output->left_count = data_[best_threshold].cnt;
+        output->left_sum_gradient = data_[best_threshold].sum_gradients;
+        output->left_sum_hessian = data_[best_threshold].sum_hessians;
+
+        output->right_output = CalculateSplittedLeafOutput(sum_gradients_ - data_[best_threshold].sum_gradients,
+          sum_hessians_ - data_[best_threshold].sum_hessians);
+        output->right_count = num_data_ - data_[best_threshold].cnt;
+        output->right_sum_gradient = sum_gradients_ - data_[best_threshold].sum_gradients;
+        output->right_sum_hessian = sum_hessians_ - data_[best_threshold].sum_hessians;
+
+        output->gain = best_gain - gain_shift;
+      } else {
+        output->feature = feature_idx_;
+        output->gain = kMinScore;
+      }
   }

 /*!