Skip to content

Commit

Permalink
Let TreeLearner share the same code of ConstructHistograms.
Browse files Browse the repository at this point in the history
guolinke committed Apr 5, 2017
1 parent b6c973a commit 98ffbb2
Showing 4 changed files with 29 additions and 46 deletions.
7 changes: 1 addition & 6 deletions src/treelearner/data_parallel_tree_learner.cpp
Original file line number Diff line number Diff line change
@@ -140,12 +140,7 @@ void DataParallelTreeLearner::BeforeTrain() {
}

void DataParallelTreeLearner::FindBestThresholds() {
train_data_->ConstructHistograms(is_feature_used_,
smaller_leaf_splits_->data_indices(), smaller_leaf_splits_->num_data_in_leaf(),
smaller_leaf_splits_->LeafIndex(),
ordered_bins_, gradients_, hessians_,
ordered_gradients_.data(), ordered_hessians_.data(),
smaller_leaf_histogram_array_[0].RawData() - 1);
ConstructHistograms(is_feature_used_, true);
// construct local histograms
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
47 changes: 26 additions & 21 deletions src/treelearner/serial_tree_learner.cpp
Original file line number Diff line number Diff line change
@@ -417,25 +417,10 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
return true;
}

void SerialTreeLearner::FindBestThresholds() {
#ifdef TIMETAG
void SerialTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
#ifdef TIMETAG
auto start_time = std::chrono::steady_clock::now();
#endif
std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!is_feature_used_[feature_index]) continue;
if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
continue;
}
is_feature_used[feature_index] = 1;
}
bool use_subtract = true;
if (parent_leaf_histogram_array_ == nullptr) {
use_subtract = false;
}
#endif
// construct smaller leaf
HistogramBinEntry* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - 1;
train_data_->ConstructHistograms(is_feature_used,
@@ -455,11 +440,31 @@ void SerialTreeLearner::FindBestThresholds() {
ordered_gradients_.data(), ordered_hessians_.data(),
ptr_larger_leaf_hist_data);
}
#ifdef TIMETAG
#ifdef TIMETAG
hist_time += std::chrono::steady_clock::now() - start_time;
#endif
#endif
}

void SerialTreeLearner::FindBestThresholds() {
std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!is_feature_used_[feature_index]) continue;
if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
continue;
}
is_feature_used[feature_index] = 1;
}

bool use_subtract = true;
if (parent_leaf_histogram_array_ == nullptr) {
use_subtract = false;
}
ConstructHistograms(is_feature_used, use_subtract);
#ifdef TIMETAG
start_time = std::chrono::steady_clock::now();
auto start_time = std::chrono::steady_clock::now();
#endif
std::vector<SplitInfo> smaller_best(num_threads_);
std::vector<SplitInfo> larger_best(num_threads_);
1 change: 1 addition & 0 deletions src/treelearner/serial_tree_learner.h
Original file line number Diff line number Diff line change
@@ -69,6 +69,7 @@ class SerialTreeLearner: public TreeLearner {
*/
virtual bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf);

void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);

/*!
* \brief Find best thresholds for all features, using multi-threading.
20 changes: 1 addition & 19 deletions src/treelearner/voting_parallel_tree_learner.cpp
Original file line number Diff line number Diff line change
@@ -260,25 +260,7 @@ void VotingParallelTreeLearner::FindBestThresholds() {
if (parent_leaf_histogram_array_ == nullptr) {
use_subtract = false;
}
// construct smaller leaf
HistogramBinEntry* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - 1;
train_data_->ConstructHistograms(is_feature_used,
smaller_leaf_splits_->data_indices(), smaller_leaf_splits_->num_data_in_leaf(),
smaller_leaf_splits_->LeafIndex(),
ordered_bins_, gradients_, hessians_,
ordered_gradients_.data(), ordered_hessians_.data(),
ptr_smaller_leaf_hist_data);

if (larger_leaf_histogram_array_ != nullptr && !use_subtract) {
// construct larger leaf
HistogramBinEntry* ptr_larger_leaf_hist_data = larger_leaf_histogram_array_[0].RawData() - 1;
train_data_->ConstructHistograms(is_feature_used,
larger_leaf_splits_->data_indices(), larger_leaf_splits_->num_data_in_leaf(),
larger_leaf_splits_->LeafIndex(),
ordered_bins_, gradients_, hessians_,
ordered_gradients_.data(), ordered_hessians_.data(),
ptr_larger_leaf_hist_data);
}
ConstructHistograms(is_feature_used, use_subtract);

std::vector<SplitInfo> smaller_bestsplit_per_features(num_features_);
std::vector<SplitInfo> larger_bestsplit_per_features(num_features_);

0 comments on commit 98ffbb2

Please sign in to comment.