96 lines
3.4 KiB
Python
96 lines
3.4 KiB
Python
|
# Pytest customization
|
||
|
import os
|
||
|
import pytest
|
||
|
import warnings
|
||
|
|
||
|
import numpy as np
|
||
|
import numpy.testing as npt
|
||
|
from scipy._lib._fpumode import get_fpu_mode
|
||
|
from scipy._lib._testutils import FPUModeChangeWarning
|
||
|
from scipy._lib import _pep440
|
||
|
|
||
|
|
||
|
def pytest_configure(config):
    """Register SciPy's custom pytest markers on ``config``.

    The ``slow``, ``xslow`` and ``xfail_on_32bit`` markers are always
    registered.  The ``timeout`` marker is registered here only when the
    ``pytest_timeout`` plugin cannot be imported, so that
    ``@pytest.mark.timeout`` remains a declared marker without the plugin.
    """
    marker_lines = (
        "slow: Tests that are very slow.",
        "xslow: mark test as extremely slow (not run unless explicitly requested)",
        "xfail_on_32bit: mark test as failing on 32-bit platforms",
    )
    for marker_line in marker_lines:
        config.addinivalue_line("markers", marker_line)

    try:
        import pytest_timeout  # noqa:F401
    except Exception:
        # Any failure to import the optional plugin (missing or broken) means
        # we declare the marker ourselves.
        config.addinivalue_line(
            "markers", 'timeout: mark a test for a non-default timeout')
|
||
|
|
||
|
|
||
|
def _get_mark(item, name):
|
||
|
if _pep440.parse(pytest.__version__) >= _pep440.Version("3.6.0"):
|
||
|
mark = item.get_closest_marker(name)
|
||
|
else:
|
||
|
mark = item.get_marker(name)
|
||
|
return mark
|
||
|
|
||
|
|
||
|
def pytest_runtest_setup(item):
    """Per-test setup hook: apply SciPy's custom markers and cap BLAS threads.

    - Tests marked ``xslow`` are skipped unless the environment variable
      ``SCIPY_XSLOW`` is set to a nonzero integer.
    - Tests marked ``xfail_on_32bit`` are xfailed when ``np.intp`` is
      narrower than 8 bytes.
    - When running under pytest-xdist with ``OMP_NUM_THREADS`` unset, BLAS
      threads are limited through threadpoolctl (if importable) to avoid
      oversubscription.

    ``pytest.skip``/``pytest.xfail`` raise, so later steps are not reached
    for a skipped or xfailed test.
    """
    _skip_xslow_unless_enabled(item)
    _xfail_on_32bit(item)
    _limit_blas_threads_under_xdist()


def _skip_xslow_unless_enabled(item):
    """Skip ``item`` if it is marked ``xslow`` and SCIPY_XSLOW is not enabled."""
    if _get_mark(item, "xslow") is None:
        return
    try:
        enabled = int(os.environ.get('SCIPY_XSLOW', '0'))
    except ValueError:
        # Non-integer values (e.g. "yes") are treated as "not enabled".
        enabled = False
    if not enabled:
        pytest.skip("very slow test; set environment variable SCIPY_XSLOW=1 to run it")


def _xfail_on_32bit(item):
    """Xfail ``item`` on 32-bit platforms if it is marked ``xfail_on_32bit``."""
    mark = _get_mark(item, 'xfail_on_32bit')
    if mark is not None and np.intp(0).itemsize < 8:
        # Tolerate the marker being used without a reason argument instead
        # of erroring with IndexError during setup.
        reason = mark.args[0] if mark.args else ''
        pytest.xfail('Fails on our 32-bit test platform(s): %s' % (reason,))


def _limit_blas_threads_under_xdist():
    """Limit BLAS threads per pytest-xdist worker to prevent oversubscription.

    Simplified version of what sklearn does (it can rely on threadpoolctl and
    its builtin OpenMP helper functions).  Does nothing when threadpoolctl is
    unavailable, when not running under xdist, or when OMP_NUM_THREADS is set.
    """
    # Older versions of threadpoolctl have an issue that may lead to this
    # warning being emitted, see gh-14441
    with npt.suppress_warnings() as sup:
        sup.filter(pytest.PytestUnraisableExceptionWarning)

        try:
            from threadpoolctl import threadpool_limits
        except Exception:  # observed in gh-14441: (ImportError, AttributeError)
            # Optional dependency only. All exceptions are caught, for robustness
            return

        try:
            xdist_worker_count = int(os.environ['PYTEST_XDIST_WORKER_COUNT'])
        except (KeyError, ValueError):
            # KeyError: pytest-xdist is not installed / not in use.
            # ValueError: malformed value; don't fail every test over it.
            return
        if xdist_worker_count < 1 or os.getenv('OMP_NUM_THREADS'):
            # A zero count would divide by zero below; an explicit
            # OMP_NUM_THREADS means the user already chose a limit.
            return

        # os.cpu_count() returns None when the count cannot be determined;
        # halving approximates the number of physical cores.
        max_openmp_threads = (os.cpu_count() or 1) // 2
        threads_per_worker = max(max_openmp_threads // xdist_worker_count, 1)
        try:
            threadpool_limits(threads_per_worker, user_api='blas')
        except Exception:
            # May raise AttributeError for older versions of OpenBLAS.
            # Catch any error for robustness.
            return
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="function", autouse=True)
def check_fpu_mode(request):
    """
    Warn if a test leaves the FPU mode different from how it found it.

    Autouse, function-scoped: the FPU mode is sampled before and after every
    test, and an ``FPUModeChangeWarning`` is emitted when the two differ.
    """
    mode_before = get_fpu_mode()
    yield
    mode_after = get_fpu_mode()

    if mode_after == mode_before:
        return

    template = ("FPU mode changed from {0:#x} to {1:#x} during "
                "the test")
    warnings.warn(template.format(mode_before, mode_after),
                  category=FPUModeChangeWarning, stacklevel=0)
|