3RNN/Lib/site-packages/sklearn/tests/random_seed.py
2024-05-26 19:49:15 +02:00

86 lines
3.2 KiB
Python

"""global_random_seed fixture
The goal of this fixture is to prevent tests that use it from being sensitive
to a specific seed value while still being deterministic by default.
See the documentation for the SKLEARN_TESTS_GLOBAL_RANDOM_SEED
variable for instructions on how to use this fixture.
https://scikit-learn.org/dev/computing/parallelism.html#sklearn-tests-global-random-seed
"""
from os import environ
from random import Random
import pytest
# Passes the main worker's random seeds to workers
class XDistHooks:
    """pytest-xdist hooks that forward the selected seeds to worker nodes."""

    def pytest_configure_node(self, node) -> None:
        """Copy the ``random_seeds`` option into the node's ``workerinput``.

        Parameters
        ----------
        node : object
            Node handed to the hook. Only ``node.config.getoption`` and the
            ``node.workerinput`` mapping are used here.
        """
        random_seeds = node.config.getoption("random_seeds")
        # ``pytest_configure`` on the receiving side looks for this exact key.
        node.workerinput["random_seeds"] = random_seeds
def _seeds_from_spec(random_seed_var):
    """Parse ``"N"`` or ``"START-STOP"`` (inclusive) into a list of seeds.

    Parameters
    ----------
    random_seed_var : str
        Raw value of SKLEARN_TESTS_GLOBAL_RANDOM_SEED that is neither
        ``"any"`` nor ``"all"``.

    Returns
    -------
    list of int
        The seeds described by the value, all within [0, 99].

    Raises
    ------
    ValueError
        If the value is not an integer or a ``START-STOP`` range, if the
        range is empty, or if any seed falls outside [0, 99].
    """
    try:
        if "-" in random_seed_var:
            start, stop = random_seed_var.split("-")
            random_seeds = list(range(int(start), int(stop) + 1))
        else:
            random_seeds = [int(random_seed_var)]
    except ValueError as exc:
        # Covers non-numeric text, a leading minus sign ("-5") and more than
        # one dash ("1-2-3"); without this the user only sees a bare
        # int()/unpacking error that never mentions the environment variable.
        raise ValueError(
            "The value of the environment variable "
            "SKLEARN_TESTS_GLOBAL_RANDOM_SEED must be an integer, a range "
            f"'start-stop', 'any' or 'all', got: {random_seed_var}"
        ) from exc

    if not random_seeds:
        # An inverted range such as "5-3" is empty; min()/max() below would
        # fail on it with an unrelated "empty sequence" error.
        raise ValueError(
            "The environment variable SKLEARN_TESTS_GLOBAL_RANDOM_SEED "
            "describes an empty range (start must not exceed stop), "
            f"got: {random_seed_var}"
        )

    if min(random_seeds) < 0 or max(random_seeds) > 99:
        raise ValueError(
            "The value(s) of the environment variable "
            "SKLEARN_TESTS_GLOBAL_RANDOM_SEED must be in the range [0, 99] "
            f"(or 'any' or 'all'), got: {random_seed_var}"
        )
    return random_seeds


def pytest_configure(config):
    """Select the seeds for this run and register the seed fixture plugin.

    The seeds come from, in order of precedence: the ``random_seeds`` entry
    of ``config.workerinput`` when present, otherwise the
    SKLEARN_TESTS_GLOBAL_RANDOM_SEED environment variable (unset, ``"any"``,
    ``"all"``, a single integer or an inclusive ``START-STOP`` range).
    The result is stored on ``config.option.random_seeds`` and used to
    parametrize the ``global_random_seed`` fixture.

    Raises
    ------
    ValueError
        If the environment variable holds a value that cannot be turned into
        seeds within [0, 99].
    """
    if config.pluginmanager.hasplugin("xdist"):
        config.pluginmanager.register(XDistHooks())

    RANDOM_SEED_RANGE = list(range(100))  # All seeds in [0, 99] should be valid.
    random_seed_var = environ.get("SKLEARN_TESTS_GLOBAL_RANDOM_SEED")
    if hasattr(config, "workerinput") and "random_seeds" in config.workerinput:
        # Set worker random seed from seed generated from main process
        random_seeds = config.workerinput["random_seeds"]
    elif random_seed_var is None:
        # Default: one fixed seed keeps the run deterministic.
        random_seeds = [42]
    elif random_seed_var == "any":
        # Pick-up one seed at random in the range of admissible random seeds.
        random_seeds = [Random().choice(RANDOM_SEED_RANGE)]
    elif random_seed_var == "all":
        random_seeds = RANDOM_SEED_RANGE
    else:
        random_seeds = _seeds_from_spec(random_seed_var)
    config.option.random_seeds = random_seeds

    # Defined here because the fixture params depend on the seeds chosen above.
    class GlobalRandomSeedPlugin:
        @pytest.fixture(params=random_seeds)
        def global_random_seed(self, request):
            """Fixture to ask for a random yet controllable random seed.

            All tests that use this fixture accept the contract that they should
            deterministically pass for any seed value from 0 to 99 included.

            See the documentation for the SKLEARN_TESTS_GLOBAL_RANDOM_SEED
            variable for instructions on how to use this fixture.

            https://scikit-learn.org/dev/computing/parallelism.html#sklearn-tests-global-random-seed
            """
            yield request.param

    config.pluginmanager.register(GlobalRandomSeedPlugin())
def pytest_report_header(config):
    """Return header lines explaining how to reproduce an ``"any"`` seed run.

    Parameters
    ----------
    config : object
        Configuration object; ``config.option.random_seeds[0]`` is read when
        SKLEARN_TESTS_GLOBAL_RANDOM_SEED is ``"any"``.

    Returns
    -------
    list of str or None
        Three header lines when the environment variable equals ``"any"``,
        otherwise ``None``.
    """
    random_seed_var = environ.get("SKLEARN_TESTS_GLOBAL_RANDOM_SEED")
    if random_seed_var != "any":
        # Any other setting is already reproducible from the variable itself.
        return None

    # In "any" mode the seed list holds the single randomly drawn seed.
    chosen_seed = config.option.random_seeds[0]
    return [
        "To reproduce this test run, set the following environment variable:",
        f' SKLEARN_TESTS_GLOBAL_RANDOM_SEED="{chosen_seed}"',
        (
            "See: https://scikit-learn.org/dev/computing/parallelism.html"
            "#sklearn-tests-global-random-seed"
        ),
    ]