3RNN/Lib/site-packages/scipy/_lib/tests/test__threadsafety.py

52 lines
1.3 KiB
Python
Raw Permalink Normal View History

2024-05-26 19:49:15 +02:00
import threading
import time
import traceback
from numpy.testing import assert_
from pytest import raises as assert_raises
from scipy._lib._threadsafety import ReentrancyLock, non_reentrant, ReentrancyError
def test_parallel_threads():
    """Check that ReentrancyLock serializes work in parallel threads.

    Three workers share one lock. Inside it, each worker:

    * checks that a shared flag is clear,
    * sets the flag,
    * sleeps for a worker-dependent time,
    * checks the flag is still set, and
    * clears it.

    If two workers were ever inside the lock at the same time, one would
    see the flag in the wrong state and record a failure.

    The test is not fully deterministic, and may succeed falsely if the
    timings go wrong.
    """
    lock = ReentrancyLock("failure")
    # One-element list used as a mutable cell shared by all worker threads.
    failflag = [False]
    errors = []

    def worker(k):
        try:
            with lock:
                assert_(not failflag[0])
                failflag[0] = True
                # Hold the lock for a while so the other threads contend for it.
                time.sleep(0.1 * k)
                assert_(failflag[0])
                failflag[0] = False
        except Exception:
            # Exceptions raised in a thread do not propagate to join();
            # record them so the main thread can report them.
            errors.append(traceback.format_exc(2))

    # Pass k via args= instead of a default-bound lambda. daemon=True lets
    # the interpreter exit even if a worker is stuck on a deadlocked lock.
    threads = [threading.Thread(target=worker, args=(k,), daemon=True)
               for k in range(3)]
    for t in threads:
        t.start()

    # Bounded join against one shared deadline: a locking regression that
    # deadlocks should fail this test rather than hang the whole suite.
    deadline = time.monotonic() + 30.0
    for t in threads:
        t.join(timeout=max(0.0, deadline - time.monotonic()))
    stuck = [t.name for t in threads if t.is_alive()]
    assert_(not stuck,
            "threads did not finish (possible deadlock): %s" % (stuck,))

    report = "\n".join(errors)
    assert_(not report, report)
def test_reentering():
    """Check that ReentrancyLock prevents re-entering from the same thread.

    ``func`` is wrapped with ``non_reentrant()`` and immediately calls
    itself. The nested call must be rejected with ``ReentrancyError``
    instead of recursing without bound.
    """
    @non_reentrant()
    def func(arg):
        # Unconditional self-call: the inner call re-enters the guarded function.
        return func(arg)

    with assert_raises(ReentrancyError):
        func(0)