Skip to content

I cannot reproduce results of quantile regression when using a custom metric or objective. #6062

Open
@HowardRiddiough

Description

Description

I am working on a project where we want to make a conservative prediction, aka increase likelihood that the model produces a negative error.

At the moment we are using LightGBM's quantile regression to predict the 90th quantile, I am happy with the results but would like to tweak the loss function ever so slightly to introduce a quadratic loss function to punish large negative errors disproportionately more than small negative errors. As I mentioned earlier we are looking to make conservative predictions but a prediction that is too conservative doesn't deliver any value.

With that end in mind I have been trying to reproduce the stock quantile regressor results, I want to begin with reproducing stock behaviour so I know I have a good foundation to start modifying the quantile loss function from.

When running the examples below you will see that the predictions made by each model do not match. That may be because I haven't constructed the quantile loss function correctly. It may also be that my custom quantile regressor does not calculate tree output in the same way as the stock quantile regressor. I can see in the regression_objective.hpp file that the RegressionQuantileloss class is doing something with percentiles when calculating the tree output that the standard regression loss may not be doing.

Reproducible example

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from lightgbm import LGBMRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error

data_url = 'http://lib.stat.cmu.edu/datasets/boston'
raw_df = pd.read_csv(data_url, sep='\s+', skiprows=22, header=None)
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
target = raw_df.values[1::2, 2]

X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2)


# Construct a model using LightGBM's quantile regressor
lgbm_stock = LGBMRegressor(objective='quantile', alpha=0.9, metric='quantile')
lgbm_stock.fit(X_train, y_train)
y_pred_stock = lgbm_stock.predict(X_test)


# Construct a model using a custom objective designed to produce the same predictions as the stock quantile regressor
alpha = 0.9

def quantile_loss_objective(preds, labels):
    error = preds - labels
    gradient = (-alpha * (error < 0) * error - (1 - alpha) * (error >= 0)) * error
    hessian = np.ones_like(error) * alpha

    return gradient, hessian


lgbm_custom = LGBMRegressor(objective=quantile_loss_objective)
lgbm_custom.fit(X_train, y_train)
y_pred_custom = lgbm_custom.predict(X_test)

# Plot predictions on test
plt.hist(y_pred_stock, label='stock', alpha=0.5)
plt.hist(y_pred_custom, label='custom', alpha=0.5)
plt.legend()
plt.title("LightGBM's quantile regressor vs LightGBM regressor with quantile loss objective")
plt.show()

Environment info

python==3.9.5
numpy=1.23.5
matplotlib==3.7.1
lightgbm==3.3.5
pandas==1.5.3
scikit-learn==1.3.0

I am working on a macbook with an M1 chip.

Summary

Thanks for taking a look at this issue and I really would appreciate any help. My questions are as follows:

  • Is it possible to mimic the stock quantile regressor using a custom objective?
  • If it is not possible, how would you recommend constructing a model where large negative errors are punished disproportionately more than small negative errors whilst maximising the likelihood that the model over-predicts? My initial goal when starting on this path was to use a custom loss function that looks something like this:
def custom_regularized_quantile_loss(y_true, y_pred, alpha):
    error = y_true - y_pred
    regularizer = np.where(error >= 0, 1, -1) * error**2  # Adjust the regularizer function
    loss = -alpha * (error < 0) * error - (1 - alpha) * (error >= 0) * error + regularizer
    return np.mean(loss)

regressor-outputs

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions