Improve tests by using global_random_seed fixture to make them less seed-sensitive #22827
Description
Context: the new global_random_seed fixture
#22749 introduces a new global_random_seed fixture to make it possible to run the same test with any seed between 0 and 99 included. By default, when SKLEARN_TESTS_GLOBAL_RANDOM_SEED is not set, this fixture deterministically returns 42 to keep test runs deterministic by default and avoid any unnecessary disruption. However, different CI builds set this seed to other arbitrary (but still deterministic) values, and nightly scheduled builds on Azure now use SKLEARN_TESTS_GLOBAL_RANDOM_SEED="any" to progressively explore any seed in the 0-99 range.
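To make the behaviour described above concrete, here is a minimal sketch of how a value of the environment variable could map to the seeds a test run uses. The helper names (`resolve_global_random_seeds`, `seeds_from_environment`) are hypothetical and this is not scikit-learn's actual implementation; it only encodes the cases mentioned in this issue (unset, "any", "all", or a fixed integer).

```python
import os
import random

ENV_VAR = "SKLEARN_TESTS_GLOBAL_RANDOM_SEED"
DEFAULT_SEED = 42  # value used when the variable is not set


def resolve_global_random_seeds(value):
    """Return the list of seeds a run would use for a given setting.

    Hypothetical helper written for illustration only.
    """
    if value is None:
        # Variable not set: stay deterministic.
        return [DEFAULT_SEED]
    if value == "any":
        # One seed drawn from the 0-99 range (both ends included).
        return [random.randint(0, 99)]
    if value == "all":
        # Every admissible seed, e.g. for a local robustness check.
        return list(range(100))
    seed = int(value)
    if not 0 <= seed <= 99:
        raise ValueError(f"seed must be between 0 and 99, got {seed}")
    return [seed]


def seeds_from_environment():
    """Read the setting from the environment (None when unset)."""
    return resolve_global_random_seeds(os.environ.get(ENV_VAR))
```

With this sketch, `resolve_global_random_seeds(None)` gives a single deterministic seed, while `"all"` expands to the hundred seeds used for the local check described in the guidelines.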
Motivation
The aim of this new fixture is to make sure that we avoid writing tests that artificially depend on a specific value of the random seed and thereby unknowingly hide a real mathematical problem in our code (see e.g. #21701 (comment)). At the same time, we still want to keep the tests deterministic and independent of the execution order by default to avoid introducing unnecessary maintenance overhead.
In addition to making the tests seed-insensitive, randomizing those tests with different seeds has the side benefit of making their assertions robust to small numerical variations that could otherwise stem from other sources, such as platform-specific or dependency-specific rounding differences that we do not cover in our existing CI infrastructure.
More details about the fixture are in the online dev doc for the SKLEARN_TESTS_GLOBAL_RANDOM_SEED env variable:
https://scikit-learn.org/dev/computing/parallelism.html#environment-variables
Guidelines to convert existing tests
- We probably do not need to convert all scikit-learn tests to use this fixture. We should instead focus our efforts on tests that actually check for important mathematical properties of our estimators or model evaluation tools. For instance, there is no need to check for the seed-insensitivity of tests that check for the exception messages raised when passing invalid inputs.
- To avoid having to review huge PRs that impact many files at once and can lead to conflicts, let's open PRs that edit at most one test file at a time. For instance use a title such as:
  TST use global_random_seed in sklearn/_loss/tests/test_glm_distribution.py
- Please reference #22827 in the description of the PR and put the full filename of the test file you edit in the title of the PR.
- To convert an existing test with a fixed seed, the general pattern is to rewrite a function such as:
    def test_some_function():
        rng = np.random.RandomState(0)
        ...
to:
    def test_some_function(global_random_seed):
        rng = np.random.RandomState(global_random_seed)
        ...
and then check that the test function is actually seed-insensitive by running with all seeds between 0 and 99 locally (can be slow! only run for one specific test at a time!):
SKLEARN_TESTS_GLOBAL_RANDOM_SEED="all" pytest sklearn/some_module/test/test_some_module.py -k test_some_function
If this is not the case, the test will probably need to be reworked to find a more stable way to check the interesting mathematical properties.
  - if the failing assertions are related to the generalization performance of a model, maybe the training set size should be slightly bigger (while keeping the test runtime as fast as possible), or the data should have fewer noisy features, or the training should be done with stronger regularization. Or more simply we can relax the tolerance threshold while ensuring it does not become trivial (e.g. by comparing to a trivial baseline);
  - if the failing assertions depend on some regularities of a synthetically generated dataset, maybe decrease the noise level of the dataset;
  - some tests might also fail when encountering data that trigger edge cases such as (near-)tied distances between datapoints that make the outcome of the computation unstable. Changing the data generation code to significantly decrease the likelihood of those edge cases (e.g. by adding more noise to the input features) can help in those cases.
- Note: in most cases, tweaking the tolerances of the assertions is not the appropriate way to make the tests pass. The first thing to do is to try to understand what the test is checking, whether the test is correct, and whether its expectations are realistic. Then, if the test seems correct and should pass for all random seeds but doesn't, investigate whether the estimator or function is bugged. As a last resort, tolerances can be loosened if the test is considered valid but aims to check a statistical property that is highly sensitive to the random seed.
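As an illustration of the advice above, here is a small self-contained sketch of a seed-parametrized check whose threshold is defined relative to a trivial baseline, followed by a loop over all 100 seeds. The check function, dataset sizes, noise level and the 0.1 factor are assumptions chosen for this example and are not taken from the scikit-learn test suite; in a real test the seed would come from the global_random_seed fixture rather than from an explicit loop.

```python
import numpy as np


def check_linear_fit_beats_baseline(seed):
    """Check that least squares clearly outperforms a mean-only baseline."""
    rng = np.random.RandomState(seed)
    n_samples, n_features = 200, 3
    X = rng.normal(size=(n_samples, n_features))
    true_coef = np.array([1.0, -2.0, 3.0])
    # Low noise level so that the property holds with a wide margin.
    y = X @ true_coef + 0.1 * rng.normal(size=n_samples)

    X_train, X_test = X[:150], X[150:]
    y_train, y_test = y[:150], y[150:]

    coef, *_ = np.linalg.lstsq(X_train, y_train, rcond=None)
    model_mse = np.mean((X_test @ coef - y_test) ** 2)
    # Trivial baseline: always predict the training mean.
    baseline_mse = np.mean((y_train.mean() - y_test) ** 2)

    # Loose, but not trivial: the baseline itself cannot satisfy this.
    assert model_mse < 0.1 * baseline_mse


# Sweep the whole admissible seed range and collect the seeds that fail.
failing_seeds = []
for seed in range(100):
    try:
        check_linear_fit_beats_baseline(seed)
    except AssertionError:
        failing_seeds.append(seed)

print("failing seeds:", failing_seeds)
```

Expressing the threshold as a fraction of the baseline error keeps the assertion meaningful even when it is loosened: a model that learned nothing would still fail it.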
In some cases, it might be very hard to write a seed-insensitive test that tolerates all seeds between 0 and 99 while still running in less than 1s. In those (hopefully rare) cases, I think it's fine to reduce the range of admissible seeds with the following pattern:
    def test_some_function(global_random_seed):
        # Making this test seed-insensitive for the 0-99 range would
        # be too costly. Restricting to the 0-9 range is necessary to
        # use small enough datasets that avoid increasing the run time
        # too much.
        rng = np.random.RandomState(global_random_seed % 10)
        ...
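The effect of the modulo in the pattern above can be verified with a few lines of plain Python. This is only a sanity check of the arithmetic, not part of any test file:

```python
from collections import Counter

# Map every value the fixture can take (0-99) to the seed actually used.
effective_seeds = Counter(seed % 10 for seed in range(100))

# Only the ten seeds 0..9 can ever reach the RandomState constructor.
print(sorted(effective_seeds))

# Each of them is selected by exactly ten different fixture values.
print(set(effective_seeds.values()))
```

The test therefore still varies with the fixture, but only has to be robust to ten distinct seeds.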
- Run the CI for tests that take a global_random_seed by pushing a commit message with the following structure:
<title> [all random seeds]
<test_name_1>
<test_name_2>
...
Note, running git commit --allow-empty lets you create a commit, and therefore a commit message, without any file changes.
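For convenience, the message structure above can be assembled programmatically. The helper below is a hypothetical sketch (the function name, the example title and the test names are made up for illustration); it only builds the text and does not call git.

```python
def build_commit_message(title, test_names):
    """Build a commit message following the structure shown above.

    Hypothetical helper: the first line is the title followed by the
    "[all random seeds]" marker, then one test name per line.
    """
    lines = [f"{title} [all random seeds]"]
    lines.extend(test_names)
    return "\n".join(lines)


# Placeholder title and test names, for illustration only.
message = build_commit_message(
    "TST check seed insensitivity",
    ["test_some_function", "test_some_other_function"],
)
print(message)
```

The resulting text can then be used as the message of an empty commit created with git commit --allow-empty.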
See the following issue for more details on why testing on the CI is necessary:
List of test modules to upgrade
find sklearn -name "test_*.py"
- sklearn/_loss/tests/test_glm_distribution.py
- sklearn/_loss/tests/test_link.py
- sklearn/_loss/tests/test_loss.py TST ensure that sklearn/_loss/tests/test_loss.py is seed insensitive #22847
- sklearn/cluster/tests/test_affinity_propagation.py
- sklearn/cluster/tests/test_bicluster.py
- sklearn/cluster/tests/test_birch.py
- sklearn/cluster/tests/test_dbscan.py
- sklearn/cluster/tests/test_feature_agglomeration.py TST use global_random_seed in sklearn/cluster/tests/test_feature_agglomeration.py #23700
- sklearn/cluster/tests/test_hierarchical.py
- sklearn/cluster/tests/test_k_means.py
- sklearn/cluster/tests/test_mean_shift.py TST Use global_random_seed in sklearn/cluster/tests/test_mean_shift.py #30517
- sklearn/cluster/tests/test_optics.py
- sklearn/cluster/tests/test_spectral.py
- sklearn/compose/tests/test_column_transformer.py
- sklearn/compose/tests/test_target.py
- sklearn/covariance/tests/test_covariance.py
- sklearn/covariance/tests/test_elliptic_envelope.py
- sklearn/covariance/tests/test_graphical_lasso.py
- sklearn/covariance/tests/test_robust_covariance.py
- sklearn/cross_decomposition/tests/test_pls.py
- sklearn/datasets/tests/test_20news.py
- sklearn/datasets/tests/test_base.py
- sklearn/datasets/tests/test_california_housing.py
- sklearn/datasets/tests/test_common.py
- sklearn/datasets/tests/test_covtype.py
- sklearn/datasets/tests/test_kddcup99.py
- sklearn/datasets/tests/test_lfw.py
- sklearn/datasets/tests/test_olivetti_faces.py
- sklearn/datasets/tests/test_openml.py
- sklearn/datasets/tests/test_rcv1.py
- sklearn/datasets/tests/test_samples_generator.py
- sklearn/datasets/tests/test_svmlight_format.py
- sklearn/decomposition/tests/test_dict_learning.py
- sklearn/decomposition/tests/test_factor_analysis.py
- sklearn/decomposition/tests/test_fastica.py
- sklearn/decomposition/tests/test_incremental_pca.py
- sklearn/decomposition/tests/test_kernel_pca.py TST use global_random_seed in sklearn/decomposition/tests/test_kernel_pca.py #30518
- sklearn/decomposition/tests/test_nmf.py
- sklearn/decomposition/tests/test_online_lda.py
- sklearn/decomposition/tests/test_pca.py
- sklearn/decomposition/tests/test_sparse_pca.py
- sklearn/decomposition/tests/test_truncated_svd.py
- sklearn/ensemble/_hist_gradient_boosting/tests/test_binning.py
- sklearn/ensemble/_hist_gradient_boosting/tests/test_bitset.py
- sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py
- sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
- sklearn/ensemble/_hist_gradient_boosting/tests/test_grower.py
- sklearn/ensemble/_hist_gradient_boosting/tests/test_histogram.py
- sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py
- sklearn/ensemble/_hist_gradient_boosting/tests/test_predictor.py
- sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py
- sklearn/ensemble/_hist_gradient_boosting/tests/test_warm_start.py
- sklearn/ensemble/tests/test_bagging.py
- sklearn/ensemble/tests/test_base.py
- sklearn/ensemble/tests/test_common.py
- sklearn/ensemble/tests/test_forest.py
- sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py
- sklearn/ensemble/tests/test_gradient_boosting.py
- sklearn/ensemble/tests/test_iforest.py TST use global_random_seed in sklearn/ensemble/tests/test_iforest.py #22901
- sklearn/ensemble/tests/test_stacking.py
- sklearn/ensemble/tests/test_voting.py
- sklearn/ensemble/tests/test_weight_boosting.py
- sklearn/experimental/tests/test_enable_hist_gradient_boosting.py
- sklearn/experimental/tests/test_enable_iterative_imputer.py
- sklearn/experimental/tests/test_enable_successive_halving.py
- sklearn/feature_extraction/tests/test_dict_vectorizer.py
- sklearn/feature_extraction/tests/test_feature_hasher.py
- sklearn/feature_extraction/tests/test_image.py
- sklearn/feature_extraction/tests/test_text.py
- sklearn/feature_selection/tests/test_base.py
- sklearn/feature_selection/tests/test_chi2.py
- sklearn/feature_selection/tests/test_feature_select.py
- sklearn/feature_selection/tests/test_from_model.py
- sklearn/feature_selection/tests/test_mutual_info.py
- sklearn/feature_selection/tests/test_rfe.py
- sklearn/feature_selection/tests/test_sequential.py
- sklearn/feature_selection/tests/test_variance_threshold.py
- sklearn/gaussian_process/tests/test_gpc.py
- sklearn/gaussian_process/tests/test_gpr.py
- sklearn/gaussian_process/tests/test_kernels.py
- sklearn/impute/tests/test_base.py
- sklearn/impute/tests/test_common.py
- sklearn/impute/tests/test_impute.py
- sklearn/impute/tests/test_knn.py
- sklearn/inspection/_plot/tests/test_plot_partial_dependence.py
- sklearn/inspection/tests/test_partial_dependence.py
- sklearn/inspection/tests/test_permutation_importance.py
- sklearn/linear_model/_glm/tests/test_glm.py
- sklearn/linear_model/_glm/tests/test_link.py
- sklearn/linear_model/tests/test_base.py
- sklearn/linear_model/tests/test_bayes.py
- sklearn/linear_model/tests/test_common.py
- sklearn/linear_model/tests/test_coordinate_descent.py
- sklearn/linear_model/tests/test_huber.py
- sklearn/linear_model/tests/test_least_angle.py
- sklearn/linear_model/tests/test_linear_loss.py
- sklearn/linear_model/tests/test_logistic.py
- sklearn/linear_model/tests/test_omp.py
- sklearn/linear_model/tests/test_passive_aggressive.py
- sklearn/linear_model/tests/test_perceptron.py
- sklearn/linear_model/tests/test_quantile.py
- sklearn/linear_model/tests/test_ransac.py
- sklearn/linear_model/tests/test_ridge.py
- sklearn/linear_model/tests/test_sag.py
- sklearn/linear_model/tests/test_sgd.py
- sklearn/linear_model/tests/test_sparse_coordinate_descent.py
- sklearn/linear_model/tests/test_theil_sen.py
- sklearn/manifold/tests/test_isomap.py
- sklearn/manifold/tests/test_locally_linear.py
- sklearn/manifold/tests/test_mds.py
- sklearn/manifold/tests/test_spectral_embedding.py
- sklearn/manifold/tests/test_t_sne.py
- sklearn/metrics/_plot/tests/test_base.py
- sklearn/metrics/_plot/tests/test_common_curve_display.py
- sklearn/metrics/_plot/tests/test_confusion_matrix_display.py
- sklearn/metrics/_plot/tests/test_det_curve_display.py
- sklearn/metrics/_plot/tests/test_plot_confusion_matrix.py
- sklearn/metrics/_plot/tests/test_plot_curve_common.py
- sklearn/metrics/_plot/tests/test_plot_det_curve.py
- sklearn/metrics/_plot/tests/test_plot_precision_recall.py
- sklearn/metrics/_plot/tests/test_plot_roc_curve.py
- sklearn/metrics/_plot/tests/test_precision_recall_display.py
- sklearn/metrics/_plot/tests/test_roc_curve_display.py
- sklearn/metrics/cluster/tests/test_bicluster.py
- sklearn/metrics/cluster/tests/test_common.py
- sklearn/metrics/cluster/tests/test_supervised.py
- sklearn/metrics/cluster/tests/test_unsupervised.py
- sklearn/metrics/tests/test_classification.py
- sklearn/metrics/tests/test_common.py
- sklearn/metrics/tests/test_dist_metrics.py
- sklearn/metrics/tests/test_pairwise_distances_reduction.py TST Ensure that sklearn/metrics/tests/test_pairwise_distances_reduction.py is seed insensitive #22862
- sklearn/metrics/tests/test_pairwise.py
- sklearn/metrics/tests/test_ranking.py
- sklearn/metrics/tests/test_regression.py
- sklearn/metrics/tests/test_score_objects.py
- sklearn/mixture/tests/test_bayesian_mixture.py
- sklearn/mixture/tests/test_gaussian_mixture.py
- sklearn/mixture/tests/test_mixture.py
- sklearn/model_selection/tests/test_search.py
- sklearn/model_selection/tests/test_split.py
- sklearn/model_selection/tests/test_successive_halving.py
- sklearn/model_selection/tests/test_validation.py
- sklearn/neighbors/tests/test_ball_tree.py
- sklearn/neighbors/tests/test_graph.py
- sklearn/neighbors/tests/test_kd_tree.py
- sklearn/neighbors/tests/test_kde.py
- sklearn/neighbors/tests/test_lof.py
- sklearn/neighbors/tests/test_nca.py
- sklearn/neighbors/tests/test_nearest_centroid.py
- sklearn/neighbors/tests/test_neighbors_pipeline.py
- sklearn/neighbors/tests/test_neighbors_tree.py
- sklearn/neighbors/tests/test_neighbors.py
- sklearn/neighbors/tests/test_quad_tree.py
- sklearn/neural_network/tests/test_base.py
- sklearn/neural_network/tests/test_mlp.py
- sklearn/neural_network/tests/test_rbm.py
- sklearn/neural_network/tests/test_stochastic_optimizers.py
- sklearn/preprocessing/tests/test_common.py
- sklearn/preprocessing/tests/test_data.py
- sklearn/preprocessing/tests/test_discretization.py
- sklearn/preprocessing/tests/test_encoders.py
- sklearn/preprocessing/tests/test_function_transformer.py
- sklearn/preprocessing/tests/test_label.py
- sklearn/preprocessing/tests/test_polynomial.py
- sklearn/semi_supervised/tests/test_label_propagation.py
- sklearn/semi_supervised/tests/test_self_training.py
- sklearn/svm/tests/test_bounds.py
- sklearn/svm/tests/test_sparse.py
- sklearn/svm/tests/test_svm.py
- sklearn/tests/test_base.py
- sklearn/tests/test_build.py
- sklearn/tests/test_calibration.py
- sklearn/tests/test_check_build.py
- sklearn/tests/test_common.py
- sklearn/tests/test_config.py
- sklearn/tests/test_discriminant_analysis.py
- sklearn/tests/test_docstring_parameters.py
- sklearn/tests/test_docstrings.py
- sklearn/tests/test_dummy.py
- sklearn/tests/test_init.py
- sklearn/tests/test_isotonic.py
- sklearn/tests/test_kernel_approximation.py
- sklearn/tests/test_kernel_ridge.py
- sklearn/tests/test_metaestimators.py
- sklearn/tests/test_min_dependencies_readme.py
- sklearn/tests/test_multiclass.py
- sklearn/tests/test_multioutput.py
- sklearn/tests/test_naive_bayes.py
- sklearn/tests/test_pipeline.py
- sklearn/tests/test_random_projection.py
- sklearn/tree/tests/test_export.py
- sklearn/tree/tests/test_reingold_tilford.py
- sklearn/tree/tests/test_tree.py
- sklearn/utils/tests/test_arpack.py
- sklearn/utils/tests/test_arrayfuncs.py
- sklearn/utils/tests/test_class_weight.py
- sklearn/utils/tests/test_cython_blas.py
- sklearn/utils/tests/test_cython_templating.py
- sklearn/utils/tests/test_deprecation.py
- sklearn/utils/tests/test_encode.py
- sklearn/utils/tests/test_estimator_checks.py
- sklearn/utils/tests/test_estimator_html_repr.py
- sklearn/utils/tests/test_extmath.py
- sklearn/utils/tests/test_fast_dict.py
- sklearn/utils/tests/test_fixes.py
- sklearn/utils/tests/test_graph.py
- sklearn/utils/tests/test_metaestimators.py
- sklearn/utils/tests/test_mocking.py
- sklearn/utils/tests/test_multiclass.py
- sklearn/utils/tests/test_murmurhash.py
- sklearn/utils/tests/test_optimize.py TST use global_random_seed in sklearn/utils/tests/test_optimize.py #30112
- sklearn/utils/tests/test_parallel.py
- sklearn/utils/tests/test_pprint.py
- sklearn/utils/tests/test_random.py
- sklearn/utils/tests/test_readonly_wrapper.py
- sklearn/utils/tests/test_seq_dataset.py
- sklearn/utils/tests/test_shortest_path.py
- sklearn/utils/tests/test_show_versions.py
- sklearn/utils/tests/test_sparsefuncs.py
- sklearn/utils/tests/test_stats.py
- sklearn/utils/tests/test_tags.py
- sklearn/utils/tests/test_testing.py
- sklearn/utils/tests/test_utils.py
- sklearn/utils/tests/test_validation.py
- sklearn/utils/tests/test_weight_vector.py
Note that some of those files might not have any test to update.
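To get a rough idea of which files are worth looking at first, one option is to count hard-coded integer seeds per test file. The script below is a heuristic sketch: the regular expression and the helper names are assumptions made for this example, it will miss seeds passed through variables, and a match does not imply that the test needs converting.

```python
import re
from pathlib import Path

# Heuristic chosen for this sketch: an integer literal passed either to
# RandomState(...) or to a random_state=... keyword argument.
HARD_CODED_SEED = re.compile(
    r"RandomState\(\s*\d+\s*\)|random_state\s*=\s*\d+"
)


def count_hard_coded_seeds(source):
    """Return the number of hard-coded integer seeds found in source."""
    return len(HARD_CODED_SEED.findall(source))


def report(root="sklearn"):
    """Print the match count for each test file under root (if any)."""
    for path in sorted(Path(root).rglob("test_*.py")):
        n_matches = count_hard_coded_seeds(path.read_text(encoding="utf-8"))
        if n_matches:
            print(f"{n_matches:4d}  {path}")


# Quick check of the heuristic on the two patterns from the guidelines.
before = "rng = np.random.RandomState(0)"
after = "rng = np.random.RandomState(global_random_seed)"
print(count_hard_coded_seeds(before), count_hard_coded_seeds(after))
```

Calling `report()` from the root of a scikit-learn checkout would list the test files containing at least one match; files with no match are likely among those that have nothing to update.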