ENH Simplify pytest global random test plugin (scikit-learn#27963)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
3 people authored Jun 28, 2024
1 parent 2107404 commit a4ebe19
Showing 5 changed files with 55 additions and 105 deletions.
5 changes: 4 additions & 1 deletion build_tools/azure/test_script.sh
@@ -11,7 +11,10 @@ if [[ "$BUILD_REASON" == "Schedule" ]]; then
     # Enable global random seed randomization to discover seed-sensitive tests
     # only on nightly builds.
     # https://scikit-learn.org/stable/computing/parallelism.html#environment-variables
-    export SKLEARN_TESTS_GLOBAL_RANDOM_SEED="any"
+    export SKLEARN_TESTS_GLOBAL_RANDOM_SEED=$(($RANDOM % 100))
+    echo "To reproduce this test run, set the following environment variable:"
+    echo " SKLEARN_TESTS_GLOBAL_RANDOM_SEED=$SKLEARN_TESTS_GLOBAL_RANDOM_SEED"
+    echo "See: https://scikit-learn.org/dev/computing/parallelism.html#sklearn-tests-global-random-seed"

     # Enable global dtype fixture for all nightly builds to discover
     # numerical-sensitive tests.
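
For reference, the nightly draw `$(($RANDOM % 100))` can be mirrored in Python. This is only a sketch: it assumes Bash's documented `RANDOM` range of 0 to 32767, and the names `BASH_RANDOM_MAX` and `draw_nightly_seed` are illustrative, not part of the scikit-learn build scripts. It also shows that the modulo makes seeds 0 to 67 very slightly more likely than seeds 68 to 99, which should not matter for the goal of varying the seed across nightly runs.

```python
import random
from collections import Counter

BASH_RANDOM_MAX = 32767  # Assumption: Bash's $RANDOM yields an integer in [0, 32767].


def draw_nightly_seed(rng=random):
    """Hypothetical Python analogue of `$(($RANDOM % 100))`."""
    return rng.randint(0, BASH_RANDOM_MAX) % 100


# Count how many possible $RANDOM values map to each seed in [0, 99].
counts = Counter(value % 100 for value in range(BASH_RANDOM_MAX + 1))

seed = draw_nightly_seed()
print(f"SKLEARN_TESTS_GLOBAL_RANDOM_SEED={seed}")
print(counts[67], counts[68])  # 328 values map to seed 67, 327 to seed 68
```

To reproduce a nightly run locally, the printed value would be exported as `SKLEARN_TESTS_GLOBAL_RANDOM_SEED` before invoking pytest, as the `echo` lines in the script suggest.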
16 changes: 6 additions & 10 deletions doc/computing/parallelism.rst
@@ -232,14 +232,12 @@ the `global_random_seed` fixture.
 All tests that use this fixture accept the contract that they should
 deterministically pass for any seed value from 0 to 99 included.

-If the `SKLEARN_TESTS_GLOBAL_RANDOM_SEED` environment variable is set to
-`"any"` (which should be the case on nightly builds on the CI), the fixture
-will choose an arbitrary seed in the above range (based on the BUILD_NUMBER or
-the current day) and all fixtured tests will run for that specific seed. The
-goal is to ensure that, over time, our CI will run all tests with different
-seeds while keeping the test duration of a single run of the full test suite
-limited. This will check that the assertions of tests written to use this
-fixture are not dependent on a specific seed value.
+In nightly CI builds, the `SKLEARN_TESTS_GLOBAL_RANDOM_SEED` environment
+variable is drawn randomly in the above range and all fixtured tests will run
+for that specific seed. The goal is to ensure that, over time, our CI will run
+all tests with different seeds while keeping the test duration of a single run
+of the full test suite limited. This will check that the assertions of tests
+written to use this fixture are not dependent on a specific seed value.

 The range of admissible seed values is limited to [0, 99] because it is often
 not possible to write a test that can work for any possible seed and we want to
@@ -250,8 +248,6 @@ Valid values for `SKLEARN_TESTS_GLOBAL_RANDOM_SEED`:
 - `SKLEARN_TESTS_GLOBAL_RANDOM_SEED="42"`: run tests with a fixed seed of 42
 - `SKLEARN_TESTS_GLOBAL_RANDOM_SEED="40-42"`: run the tests with all seeds
   between 40 and 42 included
-- `SKLEARN_TESTS_GLOBAL_RANDOM_SEED="any"`: run the tests with an arbitrary
-  seed selected between 0 and 99 included
 - `SKLEARN_TESTS_GLOBAL_RANDOM_SEED="all"`: run the tests with all seeds
   between 0 and 99 included. This can take a long time: only use for individual
   tests, not the full test suite!
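
The mapping from the remaining values of the variable to seed lists can be sketched as a small standalone function. This is an illustration only: `seeds_from_spec` is a hypothetical name, the default of 42 for an unset variable is taken from the `sklearn/conftest.py` change in this commit, and the explicit empty-range check is an addition of this sketch rather than something the commit is known to do.

```python
def seeds_from_spec(spec):
    """Hypothetical helper: map a SKLEARN_TESTS_GLOBAL_RANDOM_SEED value to seeds."""
    if spec is None:
        return [42]  # Default used when the variable is unset.
    if spec == "all":
        return list(range(100))  # Every admissible seed in [0, 99].
    if "-" in spec:
        # A range such as "40-42" includes both bounds.
        start, stop = spec.split("-")
        seeds = list(range(int(start), int(stop) + 1))
    else:
        seeds = [int(spec)]
    # Sketch-only guard: reject empty ranges such as "42-40" explicitly.
    if not seeds or min(seeds) < 0 or max(seeds) > 99:
        raise ValueError(f"seeds must be in [0, 99] (or 'all'), got: {spec}")
    return seeds


print(seeds_from_spec("42"))        # [42]
print(seeds_from_spec("40-42"))     # [40, 41, 42]
print(len(seeds_from_spec("all")))  # 100
```

With `"any"` removed, a value of `"any"` would reach the `int(...)` call and fail with a `ValueError` in this sketch.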
4 changes: 0 additions & 4 deletions setup.cfg
@@ -16,10 +16,6 @@ addopts =
     --doctest-modules
     --disable-pytest-warnings
     --color=yes
-    # Activate the plugin explicitly to ensure that the seed is reported
-    # correctly on the CI when running `pytest --pyargs sklearn` from the
-    # source folder.
-    -p sklearn.tests.random_seed

 [mypy]
 ignore_missing_imports = True
50 changes: 45 additions & 5 deletions sklearn/conftest.py
@@ -26,7 +26,6 @@
     fetch_rcv1,
     fetch_species_distributions,
 )
-from sklearn.tests import random_seed
 from sklearn.utils._testing import get_pytest_filterwarning_lines
 from sklearn.utils.fixes import (
     _IS_32BIT,
@@ -265,6 +264,51 @@ def pyplot():
     pyplot.close("all")


+def pytest_generate_tests(metafunc):
+    """Parametrization of global_random_seed fixture
+
+    based on the SKLEARN_TESTS_GLOBAL_RANDOM_SEED environment variable.
+
+    The goal of this fixture is to prevent tests that use it from being
+    sensitive to a specific seed value while still being deterministic by default.
+
+    See the documentation for the SKLEARN_TESTS_GLOBAL_RANDOM_SEED
+    variable for instructions on how to use this fixture.
+
+    https://scikit-learn.org/dev/computing/parallelism.html#sklearn-tests-global-random-seed
+
+    """
+    # When using pytest-xdist this function is called in the xdist workers.
+    # We rely on the SKLEARN_TESTS_GLOBAL_RANDOM_SEED environment variable, which
+    # is set before running pytest and is available in xdist workers since they
+    # are subprocesses.
+    RANDOM_SEED_RANGE = list(range(100))  # All seeds in [0, 99] should be valid.
+    random_seed_var = environ.get("SKLEARN_TESTS_GLOBAL_RANDOM_SEED")
+
+    default_random_seeds = [42]
+
+    if random_seed_var is None:
+        random_seeds = default_random_seeds
+    elif random_seed_var == "all":
+        random_seeds = RANDOM_SEED_RANGE
+    else:
+        if "-" in random_seed_var:
+            start, stop = random_seed_var.split("-")
+            random_seeds = list(range(int(start), int(stop) + 1))
+        else:
+            random_seeds = [int(random_seed_var)]
+
+        if min(random_seeds) < 0 or max(random_seeds) > 99:
+            raise ValueError(
+                "The value(s) of the environment variable "
+                "SKLEARN_TESTS_GLOBAL_RANDOM_SEED must be in the range [0, 99] "
+                f"(or 'all'), got: {random_seed_var}"
+            )
+
+    if "global_random_seed" in metafunc.fixturenames:
+        metafunc.parametrize("global_random_seed", random_seeds)
+
+
 def pytest_configure(config):
     # Use matplotlib agg backend during the tests including doctests
     try:
@@ -282,10 +326,6 @@ def pytest_configure(config):
         allowed_parallelism = max(allowed_parallelism // int(xdist_worker_count), 1)
     threadpool_limits(allowed_parallelism)

-    # Register global_random_seed plugin if it is not already registered
-    if not config.pluginmanager.hasplugin("sklearn.tests.random_seed"):
-        config.pluginmanager.register(random_seed)
-
     if environ.get("SKLEARN_WARNINGS_AS_ERRORS", "0") != "0":
         # This seems like the only way to programmatically change the config
         # filterwarnings. This was suggested in
85 changes: 0 additions & 85 deletions sklearn/tests/random_seed.py

This file was deleted.
