Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change functions in common.h into template functions (#969) #973

Merged
merged 6 commits into from
Oct 8, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions include/LightGBM/utils/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,20 +580,24 @@ static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred) {
}

// Check that all y[] are in interval [ymin, ymax] (end points included); throws error if not
inline void CheckElementsIntervalClosed(const float *y, float ymin, float ymax, int ny, const char *callername) {
template <typename T>
inline void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, const char *callername) {
for (int i = 0; i < ny; ++i) {
if (y[i] < ymin || y[i] > ymax) {
Log::Fatal("[%s]: does not tolerate element [#%i = %f] outside [%f, %f]", callername, i, y[i], ymin, ymax);
std::ostringstream os;
os << "[%s]: does not tolerate element [#%i = " << y[i] << "] outside [" << ymin << ", " << ymax << "]";
Log::Fatal(os.str().c_str(), callername, i);
}
}
}

// One-pass scan over array w with nw elements: find min, max and sum of elements;
// this is useful for checking weight requirements.
inline void ObtainMinMaxSum(const float *w, int nw, float *mi, float *ma, double *su) {
float minw = w[0];
float maxw = w[0];
double sumw = static_cast<double>(w[0]);
template <typename T1, typename T2>
inline void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) {
T1 minw = w[0];
T1 maxw = w[0];
T2 sumw = static_cast<T2>(w[0]);
for (int i = 1; i < nw; ++i) {
sumw += w[i];
if (w[i] < minw) minw = w[i];
Expand Down
6 changes: 3 additions & 3 deletions src/metric/xentropy_metric.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class CrossEntropyMetric : public Metric {
sum_weights_ = static_cast<double>(num_data_);
} else {
float minw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, nullptr, &sum_weights_);
Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, &sum_weights_);
if (minw < 0.0f) {
Log::Fatal("[%s:%s]: (metric) weights not allowed to be negative", GetName()[0].c_str(), __func__);
}
Expand Down Expand Up @@ -178,7 +178,7 @@ class CrossEntropyLambdaMetric : public Metric {
// check all weights are strictly positive; throw error if not
if (weights_ != nullptr) {
float minw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, nullptr, nullptr);
Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, (float*)nullptr);
if (minw <= 0.0f) {
Log::Fatal("[%s:%s]: (metric) all weights must be positive", GetName()[0].c_str(), __func__);
}
Expand Down Expand Up @@ -263,7 +263,7 @@ class KullbackLeiblerDivergence : public Metric {
sum_weights_ = static_cast<double>(num_data_);
} else {
float minw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, nullptr, &sum_weights_);
Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, &sum_weights_);
if (minw < 0.0f) {
Log::Fatal("[%s:%s]: (metric) at least one weight is negative", GetName()[0].c_str(), __func__);
}
Expand Down
2 changes: 1 addition & 1 deletion src/objective/regression_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ class RegressionPoissonLoss: public ObjectiveFunction {
// Safety check of labels
float miny;
double sumy;
Common::ObtainMinMaxSum(label_, num_data_, &miny, nullptr, &sumy);
Common::ObtainMinMaxSum(label_, num_data_, &miny, (float*)nullptr, &sumy);
if (miny < 0.0f) {
Log::Fatal("[%s]: at least one target label is negative.", GetName());
}
Expand Down
4 changes: 2 additions & 2 deletions src/objective/xentropy_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class CrossEntropy: public ObjectiveFunction {
if (weights_ != nullptr) {
float minw;
double sumw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, nullptr, &sumw);
Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, &sumw);
if (minw < 0.0f) {
Log::Fatal("[%s]: at least one weight is negative.", GetName());
}
Expand Down Expand Up @@ -163,7 +163,7 @@ class CrossEntropyLambda: public ObjectiveFunction {

if (weights_ != nullptr) {

Common::ObtainMinMaxSum(weights_, num_data_, &min_weight_, &max_weight_, nullptr);
Common::ObtainMinMaxSum(weights_, num_data_, &min_weight_, &max_weight_, (float*)nullptr);
if (min_weight_ <= 0.0f) {
Log::Fatal("[%s]: at least one weight is non-positive.", GetName());
}
Expand Down