from scipy.datasets._registry import registry
from scipy.datasets._fetchers import data_fetcher
from scipy.datasets._utils import _clear_cache
from scipy.datasets import ascent, face, electrocardiogram, download_all
from numpy.testing import assert_equal, assert_almost_equal
import os
import pytest

# `pooch` is used directly below for hashing; fail at import time with an
# explicit message when it is not installed.
try:
    import pooch
except ImportError:
    raise ImportError("Missing optional dependency 'pooch' required "
                      "for scipy.datasets module. Please use pip or "
                      "conda to install 'pooch'.")


# Path attribute of the shared fetcher; the tests below look for the
# downloaded dataset files in this directory.
data_dir = data_fetcher.path  # type: ignore


def _has_hash(path, expected_hash):
    """Return True only if `path` exists and hashes to `expected_hash`."""
    if os.path.exists(path):
        return pooch.file_hash(path) == expected_hash
    return False


def _assert_registry_hash(filename):
    """Assert that `filename` inside `data_dir` matches its registry hash."""
    assert _has_hash(os.path.join(data_dir, filename), registry[filename])


class TestDatasets:
    """Checks on the dataset loaders and on the files they leave on disk."""

    @pytest.fixture(scope='module', autouse=True)
    def test_download_all(self):
        # Setup phase shared by the tests of this module.
        # NOTE: this fixture requires an INTERNET CONNECTION.
        download_all()
        yield

    def test_existence_all(self):
        # There must be at least as many entries on disk as in the registry.
        n_on_disk = len(os.listdir(data_dir))
        assert n_on_disk >= len(registry)

    def test_ascent(self):
        image = ascent()
        assert_equal(image.shape, (512, 512))

        # hash check
        _assert_registry_hash("ascent.dat")

    def test_face(self):
        image = face()
        assert_equal(image.shape, (768, 1024, 3))

        # hash check
        _assert_registry_hash("face.dat")

    def test_electrocardiogram(self):
        # Check dtype, shape and summary statistics of the signal.
        signal = electrocardiogram()
        assert_equal(signal.dtype, float)
        assert_equal(signal.shape, (108000,))
        assert_almost_equal(signal.mean(), -0.16510875)
        assert_almost_equal(signal.std(), 0.5992473991177294)

        # hash check
        _assert_registry_hash("ecg.dat")


def test_clear_cache(tmp_path):
    """Exercise `_clear_cache` against a throw-away cache directory."""
    # Note: `tmp_path` is a pytest fixture, it handles cleanup
    cache_root = tmp_path / "dummy_cache_dir"
    cache_root.mkdir()

    # Four dummy dataset methods (data0 .. data3), each mapped to a single
    # empty file inside the dummy cache directory.
    fake_method_map = {}
    for idx in range(4):
        filename = f"data{idx}.dat"
        fake_method_map[f"data{idx}"] = [filename]
        (cache_root / filename).write_text("")

    # Clear the file of a single dataset method, data0. The callable is
    # passed bare (not wrapped in a list) to cover that calling form too.
    def data0():
        pass

    _clear_cache(datasets=data0, cache_dir=cache_root,
                 method_map=fake_method_map)
    assert not os.path.exists(cache_root / "data0.dat")

    # Clear the files of several dataset methods at once: data1 and data2.
    def data1():
        pass

    def data2():
        pass

    _clear_cache(datasets=[data1, data2], cache_dir=cache_root,
                 method_map=fake_method_map)
    for removed in ("data1.dat", "data2.dat"):
        assert not os.path.exists(cache_root / removed)

    # A dataset method that owns several files: data4 maps to both
    # "data4_0.dat" and "data4_1.dat".
    def data4():
        pass

    data4_files = ["data4_0.dat", "data4_1.dat"]
    # create files
    for filename in data4_files:
        (cache_root / filename).write_text("")
    fake_method_map["data4"] = data4_files

    _clear_cache(datasets=[data4], cache_dir=cache_root,
                 method_map=fake_method_map)
    for removed in ("data4_0.dat", "data4_1.dat"):
        assert not os.path.exists(cache_root / removed)

    # A dataset method missing from the method map (data5) should raise
    # ValueError.
    def data5():
        pass

    with pytest.raises(ValueError):
        _clear_cache(datasets=[data5], cache_dir=cache_root,
                     method_map=fake_method_map)

    # With datasets=None the whole dataset cache is removed.
    _clear_cache(datasets=None, cache_dir=cache_root)
    assert not os.path.exists(cache_root)