Skip to content

Commit

Permalink
Parallelize TreeEnsembleClassifier batch prediction (microsoft#1276)
Browse files Browse the repository at this point in the history
* use openmp for loop

* Fix Windows compile error

* Fix Windows compile error
  • Loading branch information
RandySheriffH authored Jun 25, 2019
1 parent a56b294 commit c0cf221
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -358,20 +358,23 @@ common::Status TreeEnsembleClassifier<T>::Compute(OpKernelContext* context) cons
int64_t N = x_dims.size() == 1 ? 1 : x_dims[0];
Tensor* Y = context->Output(0, TensorShape({N}));
auto* Z = context->Output(1, TensorShape({N, class_count_}));

int64_t zindex = 0;
const T* x_data = X.template Data<T>();

// for each class
std::vector<float> scores;
scores.reserve(class_count_);
common::Status status;
#ifdef USE_OPENMP
#pragma omp parallel for
#endif
for (int64_t i = 0; i < N; ++i) {
scores.clear();
int64_t zindex = i * class_count_;
std::vector<float> scores;
int64_t current_weight_0 = i * stride;
std::map<int64_t, float> classes;
// walk each tree from its root
for (size_t j = 0, end = roots_.size(); j < end; ++j) {
ORT_RETURN_IF_ERROR(ProcessTreeNode(classes, roots_[j], x_data, current_weight_0));
auto process_status = ProcessTreeNode(classes, roots_[j], x_data, current_weight_0);
if (!process_status.IsOK()) {
status = process_status;
}
}
float maxweight = 0.f;
int64_t maxclass = -1;
Expand Down Expand Up @@ -442,9 +445,8 @@ common::Status TreeEnsembleClassifier<T>::Compute(OpKernelContext* context) cons
}
}
write_scores(scores, post_transform_, zindex, Z, write_additional_scores);
zindex += scores.size();
} // namespace ml
return Status::OK();
return status;
} // namespace ml

template <typename T>
Expand Down

0 comments on commit c0cf221

Please sign in to comment.