FIX Make skorch work with sklearn 1.6.0 #1076

Closed

Conversation

@BenjaminBossan BenjaminBossan commented Dec 11, 2024

The new sklearn version now requires estimators to expose a __sklearn_tags__ method. This PR implements a bare minimum of sklearn tags to make the tests pass again.

This list of tags can probably be refined to be more precise.
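
For illustration, here is a minimal, hypothetical sketch of the new protocol (a toy classifier, not skorch's actual implementation), assuming the estimator inherits from sklearn's ClassifierMixin and BaseEstimator:

    # Toy example of the sklearn >= 1.6 tags protocol; not skorch code.
    from sklearn.base import BaseEstimator, ClassifierMixin

    class ToyClassifier(ClassifierMixin, BaseEstimator):
        def fit(self, X, y):
            return self

        def __sklearn_tags__(self):
            # start from the tags set up by BaseEstimator and the mixin ...
            tags = super().__sklearn_tags__()
            # ... then adjust individual fields as needed
            tags.non_deterministic = True
            return tags

    tags = ToyClassifier().__sklearn_tags__()
    print(tags.estimator_type)  # 'classifier'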

Moreover, some tests have started failing. To be precise, using GridSearchCV with y being a torch tensor fails, as sklearn compares the devices of y and y_pred (a numpy array) and determines that the devices differ. The error message is:

Error
skorch/tests/test_helper.py:470: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../anaconda3/envs/skorch/lib/python3.12/site-packages/sklearn/base.py:1389: in wrapper
    return fit_method(estimator, *args, **kwargs)
../../anaconda3/envs/skorch/lib/python3.12/site-packages/sklearn/model_selection/_search.py:1023: in fit
    self._run_search(evaluate_candidates)
../../anaconda3/envs/skorch/lib/python3.12/site-packages/sklearn/model_selection/_search.py:1570: in _run_search
    evaluate_candidates(ParameterGrid(self.param_grid))
../../anaconda3/envs/skorch/lib/python3.12/site-packages/sklearn/model_selection/_search.py:969: in evaluate_candidates
    out = parallel(
../../anaconda3/envs/skorch/lib/python3.12/site-packages/sklearn/utils/parallel.py:77: in __call__
    return super().__call__(iterable_with_config)
../../anaconda3/envs/skorch/lib/python3.12/site-packages/joblib/parallel.py:1918: in __call__
    return output if self.return_generator else list(output)
../../anaconda3/envs/skorch/lib/python3.12/site-packages/joblib/parallel.py:1847: in _get_sequential_output
    res = func(*args, **kwargs)
../../anaconda3/envs/skorch/lib/python3.12/site-packages/sklearn/utils/parallel.py:139: in __call__
    return self.function(*args, **kwargs)
../../anaconda3/envs/skorch/lib/python3.12/site-packages/sklearn/model_selection/_validation.py:888: in _fit_and_score
    test_scores = _score(
../../anaconda3/envs/skorch/lib/python3.12/site-packages/sklearn/model_selection/_validation.py:949: in _score
    scores = scorer(estimator, X_test, y_test, **score_params)
../../anaconda3/envs/skorch/lib/python3.12/site-packages/sklearn/metrics/_scorer.py:288: in __call__
    return self._score(partial(_cached_call, None), estimator, X, y_true, **_kwargs)
../../anaconda3/envs/skorch/lib/python3.12/site-packages/sklearn/metrics/_scorer.py:388: in _score
    return self._sign * self._score_func(y_true, y_pred, **scoring_kwargs)
../../anaconda3/envs/skorch/lib/python3.12/site-packages/sklearn/utils/_param_validation.py:216: in wrapper
    return func(*args, **kwargs)
../../anaconda3/envs/skorch/lib/python3.12/site-packages/sklearn/metrics/_classification.py:224: in accuracy_score
    xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight)
../../anaconda3/envs/skorch/lib/python3.12/site-packages/sklearn/utils/_array_api.py:614: in get_namespace_and_device
    arrays_device = device(*array_list, **skip_remove_kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

remove_none = False, remove_types = []
array_list = [tensor([1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1,
        1, 0, 1, 0, 0, 1, 0, 0, 0, 0]), array([1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1,
       0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1])]
device_ = device(type='cpu')

    def device(*array_list, remove_none=True, remove_types=(str,)):
        """Hardware device where the array data resides on.
    
        If the hardware device is not the same for all arrays, an error is raised.
    
        Parameters
        ----------
        *array_list : arrays
            List of array instances from NumPy or an array API compatible library.
    
        remove_none : bool, default=True
            Whether to ignore None objects passed in array_list.
    
        remove_types : tuple or list, default=(str,)
            Types to ignore in array_list.
    
        Returns
        -------
        out : device
            `device` object (see the "Device Support" section of the array API spec).
        """
        array_list = _remove_non_arrays(
            *array_list, remove_none=remove_none, remove_types=remove_types
        )
    
        if not array_list:
            return None
    
        device_ = _single_array_device(array_list[0])
    
        # Note: here we cannot simply use a Python `set` as it requires
        # hashable members which is not guaranteed for Array API device
        # objects. In particular, CuPy devices are not hashable at the
        # time of writing.
        for array in array_list[1:]:
            device_other = _single_array_device(array)
            if device_ != device_other:
>               raise ValueError(
                    f"Input arrays use different devices: {str(device_)}, "
                    f"{str(device_other)}"
                )
E               ValueError: Input arrays use different devices: cpu, cpu

The tests have now been amended to cast y to a numpy array. It is unfortunate that this used not to be necessary but now is.

I think we can live with that, as it's an edge case. I verified that with normal fitting, passing a torch.Tensor as y is not a problem.
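
As a rough sketch of the workaround (the data, module, and parameter grid below are made up for illustration, they are not the actual test code), y is cast to a numpy array before being passed to GridSearchCV:

    # Illustrative only: placeholders, not the skorch test suite.
    import numpy as np
    import torch
    from torch import nn
    from sklearn.model_selection import GridSearchCV
    from skorch import NeuralNetClassifier

    X = np.random.rand(40, 12).astype(np.float32)
    y = torch.randint(0, 2, (40,)).numpy()  # cast to numpy instead of passing the tensor

    class MyModule(nn.Module):
        def __init__(self, num_units=10):
            super().__init__()
            self.dense = nn.Linear(12, num_units)
            self.out = nn.Linear(num_units, 2)

        def forward(self, X):
            return torch.softmax(self.out(torch.relu(self.dense(X))), dim=-1)

    net = NeuralNetClassifier(MyModule, max_epochs=2, verbose=0)
    search = GridSearchCV(net, {'module__num_units': [10, 20]}, scoring='accuracy', cv=2)
    search.fit(X, y)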

Note: The commit message says sklearn 1.5 but it's in fact 1.6.

@BenjaminBossan BenjaminBossan (Collaborator, Author) commented

Ping @githubnemo.

Maybe @adrinjalali could also check.

@adrinjalali adrinjalali left a comment

You might also be interested in some of the utils we've put here: https://github.com/sklearn-compat/sklearn-compat/

    def __sklearn_tags__(self):
        # TODO: this is just the bare minimum, more tags should be added
        tags = super().__sklearn_tags__()
        tags.estimator_type = 'classifier'

you should probably be inheriting from ClassifierMixin instead really.

@BenjaminBossan BenjaminBossan (Collaborator, Author) replied

Right now, we don't inherit from BaseEstimator, ClassifierMixin, etc. So far, this went well but probably can cause issues like this. Would you think it is better to switch?

@adrinjalali adrinjalali replied

But you're already inheriting from ClassifierMixin in this file, so adding BaseEstimator would only make things easier.

In general, we've added a substantial number of features to BaseEstimator, and you would probably have a much easier time if you inherit from it. These days I'd strongly recommend inheriting from the right mixins.
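
A small illustration of why the base-class order matters (toy classes, not skorch's): the mixins refine the tags through cooperative super() calls, so they need to come before BaseEstimator in the MRO:

    from sklearn.base import BaseEstimator, ClassifierMixin

    class RightOrder(ClassifierMixin, BaseEstimator):
        pass

    class WrongOrder(BaseEstimator, ClassifierMixin):
        pass

    print(RightOrder().__sklearn_tags__().estimator_type)  # 'classifier'
    print(WrongOrder().__sklearn_tags__().estimator_type)  # None, the classifier tags are lost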

BenjaminBossan added a commit that referenced this pull request Dec 16, 2024
Alternative to #1076

As described in that PR, skorch is currently not compatible with sklearn
1.6.0 or above. As per suggestion, instead of implementing
__sklearn_tags__, this PR solves the issue by inheriting from
BaseEstimator.

Related changes:

- It is important to set the correct order when inheriting from
  BaseEstimator and, say, ClassifierMixin (BaseEstimator should come
  last).
- As explained in #1076, using GridSearchCV with y being a torch tensor
  fails and two tests had to be adjusted.

Unrelated changes:

- Removed unnecessary imports from callbacks/base.py.
@BenjaminBossan BenjaminBossan (Collaborator, Author) commented

Closing in favor of #1078.

@BenjaminBossan BenjaminBossan deleted the fix-implement-sklearn-tags branch December 16, 2024 13:19
BenjaminBossan added a commit that referenced this pull request Dec 18, 2024
Alternative to #1076

As described in that PR, skorch is currently not compatible with sklearn
1.6.0 or above. As per suggestion, instead of implementing
__sklearn_tags__, this PR solves the issue by inheriting from
BaseEstimator.

Related changes:

- It is important to set the correct order when inheriting from
  BaseEstimator and, say, ClassifierMixin (BaseEstimator should come
  last).
- As explained in #1076, using GridSearchCV with y being a torch tensor
  currently fails and two tests had to be adjusted.

Unrelated changes:

- Removed unnecessary imports from callbacks/base.py.