
Uniform handling of y_train in forecasting performance metrics #712

Closed
@mloning

Description

Is your feature request related to a problem? Please describe.
Some forecasting performance metrics, e.g. MASE, require y_train in addition to y_true and y_pred. This complicates higher-level functionality that expects a common metric interface, such as the evaluate function or ForecastingGridSearchCV, as well as unit testing (see #672).
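MASE needs y_train because its denominator comes from the training series: in the usual definition, the mean absolute forecast error is divided by the in-sample mean absolute error of the (seasonal) naive forecast on y_train. A minimal NumPy sketch of that definition, assuming `sp` as the name of the seasonal-period parameter; sktime's own implementation may differ in details such as input validation:

```python
import numpy as np


def mase(y_true, y_pred, y_train, sp=1):
    """Mean absolute scaled error (minimal sketch, not sktime's implementation)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    y_train = np.asarray(y_train, dtype=float)
    # Scale: in-sample MAE of the seasonal naive forecast y_t = y_{t-sp}
    # on the training series. This term cannot be computed without y_train.
    scale = np.mean(np.abs(y_train[sp:] - y_train[:-sp]))
    return float(np.mean(np.abs(y_true - y_pred)) / scale)


print(mase([6, 7], [5, 5], y_train=[1, 2, 3, 4, 5]))   # 1.5
print(mase([6, 7], [5, 5], y_train=[2, 4, 6, 8, 10]))  # 0.75
```

The two calls have identical forecast errors but different scores, because only y_train differs; a metric interface limited to (y_true, y_pred) cannot express this.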

Current problem
The following currently fails because evaluate does not pass y_train to the scoring metric internally.

```python
from sktime.forecasting.all import *
from sktime.forecasting.model_evaluation import evaluate

y = load_airline()
f = NaiveForecaster()
cv = SlidingWindowSplitter()
scoring = MASE()
out = evaluate(f, cv, y, scoring=scoring)
```
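The same failure mode can be reproduced without sktime. The hypothetical `naive_evaluate` below is a stripped-down sliding-window loop (not sktime's `evaluate`) that only knows the two-argument call `scoring(y_test, y_pred)`, mimicking the situation described above: a two-argument metric works, while a metric that also needs `y_train` raises a `TypeError`.

```python
import numpy as np


def mae(y_true, y_pred):
    """Two-argument metric: works with the loop below."""
    return float(np.mean(np.abs(np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float))))


def mase(y_true, y_pred, y_train, sp=1):
    """Three-argument metric: needs the training window."""
    y_train = np.asarray(y_train, dtype=float)
    scale = np.mean(np.abs(y_train[sp:] - y_train[:-sp]))
    return mae(y_true, y_pred) / scale


def naive_evaluate(y, scoring, window_length=5, fh=1):
    """Hypothetical, stripped-down sliding-window evaluation loop."""
    y = np.asarray(y, dtype=float)
    scores = []
    for start in range(len(y) - window_length - fh + 1):
        y_train = y[start:start + window_length]
        y_test = y[start + window_length:start + window_length + fh]
        y_pred = np.repeat(y_train[-1], fh)  # last-value naive forecast
        # The loop only knows the two-argument signature: y_train is never passed.
        scores.append(scoring(y_test, y_pred))
    return scores


print(naive_evaluate(np.arange(10), mae))  # [1.0, 1.0, 1.0, 1.0, 1.0]
try:
    naive_evaluate(np.arange(10), mase)
except TypeError as err:
    print(err)  # mase() missing 1 required positional argument: 'y_train'
```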

Possible solutions

  1. Change the interface of all performance metrics to optionally accept y_train, with only the metrics that need it actually using it. This requires wrapping metrics from scikit-learn.
  2. Add case distinctions in higher-level functionality to handle metrics that require y_train separately from those that do not. This requires adding a requires_y_train attribute to metric classes.
  3. Adapt the metric interface at run time with an adapter that makes the case distinction internally and exposes a uniform interface to higher-level functionality (suggested by @fkiraly). This also requires adding a requires_y_train attribute to metric classes.
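A minimal sketch contrasting options 1 and 2 on two toy NumPy metrics (`mae`, `mase`). The helper names `accept_y_train` and `score`, and setting `requires_y_train` on plain functions, are hypothetical illustrations, not sktime API; a scikit-learn function such as `mean_absolute_error` would be wrapped the same way as `mae` here. Option 3 corresponds to the adapter implementation further below.

```python
import functools

import numpy as np


def mae(y_true, y_pred):
    """Two-argument metric, like most scikit-learn metrics."""
    return float(np.mean(np.abs(np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float))))


def mase(y_true, y_pred, y_train, sp=1):
    """Three-argument metric that needs the training series."""
    y_train = np.asarray(y_train, dtype=float)
    scale = np.mean(np.abs(y_train[sp:] - y_train[:-sp]))
    return mae(y_true, y_pred) / scale


# Option 1: give every metric the same signature by wrapping two-argument
# metrics so that they accept, and ignore, an optional y_train.
def accept_y_train(metric_fn):
    @functools.wraps(metric_fn)
    def wrapped(y_true, y_pred, y_train=None, **kwargs):
        return metric_fn(y_true, y_pred, **kwargs)
    return wrapped


# Option 2: keep metric signatures unchanged and branch in the caller,
# based on a (hypothetical) requires_y_train flag on each metric.
mae.requires_y_train = False
mase.requires_y_train = True


def score(metric, y_true, y_pred, y_train):
    if metric.requires_y_train:
        return metric(y_true, y_pred, y_train)
    return metric(y_true, y_pred)


print(accept_y_train(mae)([1, 2], [1, 4], y_train=[0, 0]))  # 1.0
print(score(mase, [6, 7], [5, 5], [1, 2, 3, 4, 5]))  # 1.5
```

Option 1 changes every metric once; option 2 leaves metrics alone but repeats the `if` in every higher-level function, which is the duplication the adapter in option 3 is meant to centralise.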

Describe the solution you'd like

```python
import numpy as np
from sktime.forecasting.all import *
from sktime.forecasting.model_evaluation import evaluate

y = load_airline()
fh = np.arange(1, 10)
y_train, y_test = temporal_train_test_split(y, fh=fh)
f = NaiveForecaster()
f.fit(y_train)
y_pred = f.predict(fh)

# uniform interface: every adapted metric is called as scoring(y_true, y_pred, y_train)
scoring = MASE()
scoring.requires_y_train = True  # proposed attribute, set by hand for illustration
scoring = check_scoring(scoring)  # proposed check_scoring, implemented below
scoring(y_test, y_pred, y_train)
# 3.577770878609128

scoring = sMAPE()
scoring.requires_y_train = False
scoring = check_scoring(scoring)
scoring(y_test, y_pred, y_train)  # y_train is accepted but ignored
# 0.1780237534499896
```

Here's a rough implementation of the adapter-based solution:

```python
class _MetricAdapter:
    """
    Adapter for performance metrics to uniformly handle
    the y_train requirement of some metrics.
    """

    def __init__(self, metric):
        # wrap metric object
        self.metric = metric

    def __call__(self, y_true, y_pred, y_train, *args, **kwargs):
        """Compute metric, uniformly handling those metrics that
        require `y_train` and those that do not.
        """
        # if y_train is required, pass it on; metrics without the
        # flag are treated as not requiring y_train
        if getattr(self.metric, "requires_y_train", False):
            return self.metric(y_true, y_pred, y_train, *args, **kwargs)

        # otherwise, ignore y_train
        return self.metric(y_true, y_pred, *args, **kwargs)

    def __getattr__(self, attr):
        # __getattr__ is only called for attributes not found on the adapter;
        # guard against infinite recursion when `metric` is not set yet
        # (e.g. while the object is being copied or unpickled)
        if attr == "metric":
            raise AttributeError(attr)
        # delegate attribute queries to the wrapped metric object
        return getattr(self.metric, attr)

    def __repr__(self):
        return repr(self.metric)


def _adapt_scoring(scoring):
    """Helper function to adapt scoring to uniformly handle y_train requirement"""
    return _MetricAdapter(scoring)


def check_scoring(scoring):
    """
    Validate `scoring` object.

    Parameters
    ----------
    scoring : object
        Callable metric object.

    Returns
    -------
    scoring : _MetricAdapter
        Validated `scoring` object wrapped in `_MetricAdapter`, or an adapted
        sMAPE() if `scoring` is None.

    Raises
    ------
    TypeError
        If `scoring` is not a callable object or does not inherit from
        `MetricFunctionWrapper`.
    """
    from sktime.performance_metrics.forecasting import sMAPE
    from sktime.performance_metrics.forecasting._classes import MetricFunctionWrapper

    # fall back to sMAPE, but adapt it like any other metric so that callers
    # can always use the uniform (y_true, y_pred, y_train) signature
    if scoring is None:
        scoring = sMAPE()

    if not callable(scoring):
        raise TypeError("`scoring` must be a callable object")

    valid_base_class = MetricFunctionWrapper
    if not isinstance(scoring, valid_base_class):
        raise TypeError(f"`scoring` must inherit from `{valid_base_class.__name__}`")

    return _adapt_scoring(scoring)
```

Labels: API design (API design & software architecture), feature request (New feature or request)
