Skip to content

Commit

Permalink
Make operator TreeEnsemble 5x faster for batches of size 100.000 (mic…
Browse files Browse the repository at this point in the history
…rosoft#5965)

* improves processing time by a factor of 10
* extend unit test coverage
* better implementation for the multi regression case
* better comment, keep parallelization by trees when not enough trees
  • Loading branch information
xadupre authored Dec 3, 2020
1 parent 524b9fa commit 0acc383
Show file tree
Hide file tree
Showing 4 changed files with 301 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ template <typename T>
TreeEnsembleClassifier<T>::TreeEnsembleClassifier(const OpKernelInfo& info)
: OpKernel(info),
tree_ensemble_(
100,
80,
50,
info.GetAttrOrDefault<std::string>("aggregate_function", "SUM"),
info.GetAttrsOrDefault<float>("base_values"),
Expand Down
224 changes: 139 additions & 85 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,126 +262,180 @@ void TreeEnsembleCommon<ITYPE, OTYPE>::ComputeAgg(concurrency::ThreadPool* ttp,
const ITYPE* x_data = X->template Data<ITYPE>();
OTYPE* z_data = Z->template MutableData<OTYPE>();
int64_t* label_data = label == nullptr ? nullptr : label->template MutableData<int64_t>();
auto max_num_threads = concurrency::ThreadPool::DegreeOfParallelism(ttp);

if (n_targets_or_classes_ == 1) {
if (N == 1) {
ScoreValue<OTYPE> score = {0, 0};
if (n_trees_ <= parallel_tree_) {
if (n_trees_ <= parallel_tree_) { /* section A: 1 output, 1 row and not enough trees to parallelize */
for (int64_t j = 0; j < n_trees_; ++j) {
agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data));
}
} else {
std::vector<ScoreValue<OTYPE>> scores_t(n_trees_, {0, 0});
} else { /* section B: 1 output, 1 row and enough trees to parallelize */
std::vector<ScoreValue<OTYPE>> scores(n_trees_, {0, 0});
concurrency::ThreadPool::TryBatchParallelFor(
ttp,
SafeInt<int32_t>(n_trees_),
[this, &scores_t, &agg, x_data](ptrdiff_t j) {
agg.ProcessTreeNodePrediction1(scores_t[j], *ProcessTreeNodeLeave(roots_[j], x_data));
[this, &scores, &agg, x_data](ptrdiff_t j) {
agg.ProcessTreeNodePrediction1(scores[j], *ProcessTreeNodeLeave(roots_[j], x_data));
},
0);

for (auto it = scores_t.cbegin(); it != scores_t.cend(); ++it) {
for (auto it = scores.cbegin(); it != scores.cend(); ++it) {
agg.MergePrediction1(score, *it);
}
}

agg.FinalizeScores1(z_data, score, label_data);
} else {
if (N <= parallel_N_) {
ScoreValue<OTYPE> score;
size_t j;

for (int64_t i = 0; i < N; ++i) {
score = {0, 0};
for (j = 0; j < static_cast<size_t>(n_trees_); ++j) {
agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}

agg.FinalizeScores1(z_data + i * n_targets_or_classes_, score,
label_data == nullptr ? nullptr : (label_data + i));
} else if (N <= parallel_N_) { /* section C: 1 output, 2+ rows but not enough rows to parallelize */
ScoreValue<OTYPE> score;
size_t j;

for (int64_t i = 0; i < N; ++i) {
score = {0, 0};
for (j = 0; j < static_cast<size_t>(n_trees_); ++j) {
agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}
} else {
concurrency::ThreadPool::TryBatchParallelFor(
ttp,
SafeInt<int32_t>(N),
[this, &agg, x_data, z_data, stride, label_data](ptrdiff_t i) {
ScoreValue<OTYPE> score = {0, 0};
for (size_t j = 0; j < static_cast<size_t>(n_trees_); ++j) {
agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}

agg.FinalizeScores1(z_data + i * n_targets_or_classes_, score,
label_data == nullptr ? nullptr : (label_data + i));
},
0);
agg.FinalizeScores1(z_data + i, score,
label_data == nullptr ? nullptr : (label_data + i));
}
} else if (n_trees_ > max_num_threads) { /* section D: 1 output, 2+ rows and enough trees to parallelize */
auto num_threads = std::min<int32_t>(max_num_threads, SafeInt<int32_t>(n_trees_));
std::vector<ScoreValue<OTYPE>> scores(num_threads * N);
concurrency::ThreadPool::TrySimpleParallelFor(
ttp,
num_threads,
[this, &agg, &scores, num_threads, x_data, N, stride](ptrdiff_t batch_num) {
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, this->n_trees_);
for (int64_t i = 0; i < N; ++i) {
scores[batch_num * N + i] = {0, 0};
}
for (auto j = work.start; j < work.end; ++j) {
for (int64_t i = 0; i < N; ++i) {
agg.ProcessTreeNodePrediction1(scores[batch_num * N + i], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}
}
});

concurrency::ThreadPool::TrySimpleParallelFor(
ttp,
num_threads,
[&agg, &scores, num_threads, label_data, z_data, N](ptrdiff_t batch_num) {
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N);
for (auto i = work.start; i < work.end; ++i) {
for (int64_t j = 1; j < num_threads; ++j) {
agg.MergePrediction1(scores[i], scores[j * N + i]);
}
agg.FinalizeScores1(z_data + i, scores[i],
label_data == nullptr ? nullptr : (label_data + i));
}
});
} else { /* section E: 1 output, 2+ rows, parallelization by rows */
concurrency::ThreadPool::TryBatchParallelFor(
ttp,
SafeInt<int32_t>(N),
[this, &agg, x_data, z_data, stride, label_data](ptrdiff_t i) {
ScoreValue<OTYPE> score = {0, 0};
for (size_t j = 0; j < static_cast<size_t>(n_trees_); ++j) {
agg.ProcessTreeNodePrediction1(score, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}

agg.FinalizeScores1(z_data + i, score,
label_data == nullptr ? nullptr : (label_data + i));
},
0);
}
} else {
if (N == 1) {
std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_, {0, 0});
if (n_trees_ <= parallel_tree_) {
if (N == 1) { /* section A2: 2+ outputs, 1 row, not enough trees to parallelize */
if (n_trees_ <= parallel_tree_) { /* section A2 */
std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_, {0, 0});
for (int64_t j = 0; j < n_trees_; ++j) {
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data));
}
} else {
// split the work into one block per thread so we can re-use the 'private_scores' vector as much as possible
// TODO: Refine the number of threads used
auto num_threads = std::min<int32_t>(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt<int32_t>(n_trees_));
OrtMutex merge_mutex;
agg.FinalizeScores(scores, z_data, -1, label_data);
} else { /* section B2: 2+ outputs, 1 row, enough trees to parallelize */
auto num_threads = std::min<int32_t>(max_num_threads, SafeInt<int32_t>(n_trees_));
std::vector<std::vector<ScoreValue<OTYPE>>> scores(num_threads);
concurrency::ThreadPool::TrySimpleParallelFor(
ttp,
num_threads,
[this, &agg, &scores, &merge_mutex, num_threads, x_data](ptrdiff_t batch_num) {
std::vector<ScoreValue<OTYPE>> private_scores(n_targets_or_classes_, {0, 0});
[this, &agg, &scores, num_threads, x_data](ptrdiff_t batch_num) {
scores[batch_num].resize(n_targets_or_classes_, {0, 0});
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, n_trees_);
for (auto j = work.start; j < work.end; ++j) {
agg.ProcessTreeNodePrediction(private_scores, *ProcessTreeNodeLeave(roots_[j], x_data));
agg.ProcessTreeNodePrediction(scores[batch_num], *ProcessTreeNodeLeave(roots_[j], x_data));
}

std::lock_guard<OrtMutex> lock(merge_mutex);
agg.MergePrediction(scores, private_scores);
});
for (size_t i = 1; i < scores.size(); ++i) {
agg.MergePrediction(scores[0], scores[i]);
}
agg.FinalizeScores(scores[0], z_data, -1, label_data);
}

agg.FinalizeScores(scores, z_data, -1, label_data);
} else {
if (N <= parallel_N_) {
std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_);
size_t j;

for (int64_t i = 0; i < N; ++i) {
std::fill(scores.begin(), scores.end(), ScoreValue<OTYPE>({0, 0}));
for (j = 0; j < roots_.size(); ++j) {
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}

agg.FinalizeScores(scores, z_data + i * n_targets_or_classes_, -1,
label_data == nullptr ? nullptr : (label_data + i));
} else if (N <= parallel_N_) { /* section C2: 2+ outputs, 2+ rows, not enough rows to parallelize */
std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_);
size_t j;

for (int64_t i = 0; i < N; ++i) {
std::fill(scores.begin(), scores.end(), ScoreValue<OTYPE>({0, 0}));
for (j = 0; j < roots_.size(); ++j) {
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}
} else {
// split the work into one block per thread so we can re-use the 'scores' vector as much as possible
// TODO: Refine the number of threads used.
auto num_threads = std::min<int32_t>(concurrency::ThreadPool::DegreeOfParallelism(ttp), SafeInt<int32_t>(N));
concurrency::ThreadPool::TrySimpleParallelFor(
ttp,
num_threads,
[this, &agg, num_threads, x_data, z_data, label_data, N, stride](ptrdiff_t batch_num) {
size_t j;
std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_);
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N);

for (auto i = work.start; i < work.end; ++i) {
std::fill(scores.begin(), scores.end(), ScoreValue<OTYPE>({0, 0}));
for (j = 0; j < roots_.size(); ++j) {
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}

agg.FinalizeScores(scores,
z_data + i * n_targets_or_classes_, -1,
label_data == nullptr ? nullptr : (label_data + i));
}
});

agg.FinalizeScores(scores, z_data + i * n_targets_or_classes_, -1,
label_data == nullptr ? nullptr : (label_data + i));
}
} else if (n_trees_ >= max_num_threads) { /* section: D2: 2+ outputs, 2+ rows, enough trees to parallelize*/
auto num_threads = std::min<int32_t>(max_num_threads, SafeInt<int32_t>(n_trees_));
std::vector<std::vector<ScoreValue<OTYPE>>> scores(num_threads * N);
concurrency::ThreadPool::TrySimpleParallelFor(
ttp,
num_threads,
[this, &agg, &scores, num_threads, x_data, N, stride](ptrdiff_t batch_num) {
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, this->n_trees_);
for (int64_t i = 0; i < N; ++i) {
scores[batch_num * N + i].resize(n_targets_or_classes_, {0, 0});
}
for (auto j = work.start; j < work.end; ++j) {
for (int64_t i = 0; i < N; ++i) {
agg.ProcessTreeNodePrediction(scores[batch_num * N + i], *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}
}
});

concurrency::ThreadPool::TrySimpleParallelFor(
ttp,
num_threads,
[this, &agg, &scores, num_threads, label_data, z_data, N](ptrdiff_t batch_num) {
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N);
for (auto i = work.start; i < work.end; ++i) {
for (int64_t j = 1; j < num_threads; ++j) {
agg.MergePrediction(scores[i], scores[j * N + i]);
}
agg.FinalizeScores(scores[i], z_data + i * this->n_targets_or_classes_, -1,
label_data == nullptr ? nullptr : (label_data + i));
}
});
} else { /* section E2: 2+ outputs, 2+ rows, parallelization by rows */
auto num_threads = std::min<int32_t>(max_num_threads, SafeInt<int32_t>(N));
concurrency::ThreadPool::TrySimpleParallelFor(
ttp,
num_threads,
[this, &agg, num_threads, x_data, z_data, label_data, N, stride](ptrdiff_t batch_num) {
size_t j;
std::vector<ScoreValue<OTYPE>> scores(n_targets_or_classes_);
auto work = concurrency::ThreadPool::PartitionWork(batch_num, num_threads, N);

for (auto i = work.start; i < work.end; ++i) {
std::fill(scores.begin(), scores.end(), ScoreValue<OTYPE>({0, 0}));
for (j = 0; j < roots_.size(); ++j) {
agg.ProcessTreeNodePrediction(scores, *ProcessTreeNodeLeave(roots_[j], x_data + i * stride));
}

agg.FinalizeScores(scores,
z_data + i * n_targets_or_classes_, -1,
label_data == nullptr ? nullptr : (label_data + i));
}
});
}
}
} // namespace detail
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/ml/treeregressor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ template <typename T>
TreeEnsembleRegressor<T>::TreeEnsembleRegressor(const OpKernelInfo& info)
: OpKernel(info),
tree_ensemble_(
100,
80,
50,
info.GetAttrOrDefault<std::string>("aggregate_function", "SUM"),
info.GetAttrsOrDefault<float>("base_values"),
Expand Down
Loading

0 comments on commit 0acc383

Please sign in to comment.