import time
from concurrent.futures import ThreadPoolExecutor

import pytest

from sklearn import get_config, set_config, config_context
from sklearn.utils.parallel import delayed, Parallel

# Expected global configuration when no `config_context` is active and no
# `set_config` override is in effect. Kept in one place so that a change to the
# library defaults only needs to be reflected here once.
_DEFAULT_CONFIG = {
    "assume_finite": False,
    "working_memory": 1024,
    "print_changed_only": True,
    "display": "diagram",
    "array_api_dispatch": False,
    "pairwise_dist_chunk_size": 256,
    "enable_cython_pairwise_dist": True,
    "transform_output": "default",
}


def test_config_context():
    """Check that `config_context` scopes its changes and restores them on exit."""
    assert get_config() == _DEFAULT_CONFIG

    # Not using as a context manager affects nothing
    config_context(assume_finite=True)
    assert get_config()["assume_finite"] is False

    with config_context(assume_finite=True):
        assert get_config() == {**_DEFAULT_CONFIG, "assume_finite": True}
    assert get_config()["assume_finite"] is False

    with config_context(assume_finite=True):
        # Passing None leaves the value inherited from the enclosing context.
        with config_context(assume_finite=None):
            assert get_config()["assume_finite"] is True

        assert get_config()["assume_finite"] is True

        with config_context(assume_finite=False):
            assert get_config()["assume_finite"] is False

            with config_context(assume_finite=None):
                assert get_config()["assume_finite"] is False

                # global setting will not be retained outside of context that
                # did not modify this setting
                set_config(assume_finite=True)
                assert get_config()["assume_finite"] is True

            assert get_config()["assume_finite"] is False

        assert get_config()["assume_finite"] is True

    assert get_config() == _DEFAULT_CONFIG

    # No positional arguments
    with pytest.raises(TypeError):
        config_context(True)

    # No unknown arguments
    with pytest.raises(TypeError):
        config_context(do_something_else=True).__enter__()


def test_config_context_exception():
    """Check that the config is restored when the context exits via an exception."""
    assert get_config()["assume_finite"] is False
    # pytest.raises also fails the test if the ValueError is swallowed inside
    # the context instead of propagating out of it.
    with pytest.raises(ValueError):
        with config_context(assume_finite=True):
            assert get_config()["assume_finite"] is True
            raise ValueError()
    assert get_config()["assume_finite"] is False


def test_set_config():
    """Check `set_config` semantics: None keeps the current value, unknown keys fail."""
    assert get_config()["assume_finite"] is False
    set_config(assume_finite=None)
    assert get_config()["assume_finite"] is False
    set_config(assume_finite=True)
    assert get_config()["assume_finite"] is True
    set_config(assume_finite=None)
    assert get_config()["assume_finite"] is True
    set_config(assume_finite=False)
    assert get_config()["assume_finite"] is False

    # No unknown arguments
    with pytest.raises(TypeError):
        set_config(do_something_else=True)


def set_assume_finite(assume_finite, sleep_duration):
    """Return the value of assume_finite after waiting `sleep_duration`.

    The sleep happens inside the `config_context`, so that when this is run by
    several concurrent workers their contexts overlap in time.
    """
    with config_context(assume_finite=assume_finite):
        time.sleep(sleep_duration)
        return get_config()["assume_finite"]


@pytest.mark.parametrize("backend", ["loky", "multiprocessing", "threading"])
def test_config_threadsafe_joblib(backend):
    """Test that the global config is threadsafe with all joblib backends.

    Four tasks are dispatched to two workers. They alternately set
    assume_finite to False (sleeping 0.1s) and to True (sleeping 0.2s). Each
    task must read back the value it set itself; in other words, it is not
    influenced by a concurrently running task setting a different value.
    """
    assume_finites = [False, True, False, True]
    sleep_durations = [0.1, 0.2, 0.1, 0.2]

    items = Parallel(backend=backend, n_jobs=2)(
        delayed(set_assume_finite)(assume_finite, sleep_dur)
        for assume_finite, sleep_dur in zip(assume_finites, sleep_durations)
    )

    assert items == [False, True, False, True]


def test_config_threadsafe():
    """Use threads directly to test that the global config does not leak
    between threads.

    Same test as `test_config_threadsafe_joblib` but with `ThreadPoolExecutor`.
    """
    assume_finites = [False, True, False, True]
    sleep_durations = [0.1, 0.2, 0.1, 0.2]

    with ThreadPoolExecutor(max_workers=2) as executor:
        # Executor.map yields results in input order, so they can be compared
        # positionally against the values that were set.
        items = list(executor.map(set_assume_finite, assume_finites, sleep_durations))

    assert items == [False, True, False, True]