90 lines
2.5 KiB
Python
90 lines
2.5 KiB
Python
|
import pytest
|
||
|
|
||
|
from fsspec.callbacks import Callback, TqdmCallback
|
||
|
|
||
|
|
||
|
def test_callbacks():
    """Callback.call invokes the registered hook and returns its result."""
    # Calling an event with no registered hook returns None.
    noop = Callback()
    assert noop.call("something", somearg=None) is None

    # A hook taking one keyword receives the forwarded kwarg.
    add_two = Callback(hooks={"something": lambda *_, arg=None: arg + 2})
    assert add_two.call("something", arg=2) == 4

    # Several keyword arguments are all forwarded to the hook.
    summing = Callback(
        hooks={"something": lambda *_, arg1=None, arg2=None: arg1 + arg2}
    )
    assert summing.call("something", arg1=2, arg2=2) == 4
|
||
|
|
||
|
|
||
|
def test_callbacks_as_callback():
    """Callback.as_callback turns None into a shared no-op and keeps real hooks."""
    noop = Callback.as_callback(None)
    assert noop.call("something", arg="somearg") is None
    # Converting None repeatedly yields the very same object.
    assert Callback.as_callback(None) is Callback.as_callback(None)

    # A real Callback passed through as_callback still dispatches its hooks.
    converted = Callback.as_callback(
        Callback(hooks={"something": lambda *_, arg=None: arg + 2})
    )
    assert converted.call("something", arg=2) == 4
|
||
|
|
||
|
|
||
|
def test_callbacks_as_context_manager(mocker):
    """Using a Callback in a ``with`` block yields a Callback and calls close once."""
    close_spy = mocker.spy(Callback, "close")

    with Callback() as entered:
        assert isinstance(entered, Callback)

    close_spy.assert_called_once()
|
||
|
|
||
|
|
||
|
def test_callbacks_branched():
    """branched() returns a Callback that is a different object from its parent."""
    parent = Callback()
    child = parent.branched("path_1", "path_2")

    assert isinstance(child, Callback)
    assert child is not parent
|
||
|
|
||
|
|
||
|
@pytest.mark.asyncio
async def test_callbacks_branch_coro(mocker):
    """branch_coro wraps a coroutine so each call goes through ``branched``.

    Checks that the wrapped coroutine's return value is passed back, that
    ``branched`` receives the same positional and keyword arguments, and that
    the original coroutine is given the object returned by ``branched`` as its
    ``callback=`` argument.
    """
    async_fn = mocker.AsyncMock(return_value=10)
    callback = Callback()
    # Install the spy *before* wrapping, so the assertions hold even if
    # branch_coro were to capture ``self.branched`` at wrap time rather than
    # look it up on every call.
    spy = mocker.spy(callback, "branched")
    wrapped_fn = callback.branch_coro(async_fn)

    assert await wrapped_fn("path_1", "path_2", key="value") == 10

    spy.assert_called_once_with("path_1", "path_2", key="value")
    async_fn.assert_called_once_with(
        "path_1", "path_2", callback=spy.spy_return, key="value"
    )
|
||
|
|
||
|
|
||
|
def test_callbacks_wrap():
    """Exhausting wrap() over 10 items records ten relative_update(1) calls."""
    seen = []

    class _Recorder(Callback):
        def relative_update(self, inc=1):
            seen.append(inc)

    # Drain the wrapped iterator; only the recorded side effects matter.
    list(_Recorder().wrap(range(10)))

    assert seen == [1] * 10
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("tqdm_kwargs", [{}, {"desc": "A custom desc"}])
def test_tqdm_callback(tqdm_kwargs, mocker):
    """TqdmCallback creates its bar with total=size plus any tqdm_kwargs.

    The tqdm class is replaced with a mock, so no real progress bar is drawn.
    """
    # Skip when tqdm is not installed, before TqdmCallback is constructed.
    pytest.importorskip("tqdm")
    callback = TqdmCallback(tqdm_kwargs=tqdm_kwargs)
    mocker.patch.object(callback, "_tqdm_cls")

    callback.set_size(10)
    for _ in callback.wrap(range(10)):
        ...

    # Presumably one update from set_size plus one per wrapped item.
    assert callback.tqdm.update.call_count == 11
    # ``**{}`` expands to nothing, so one assertion covers both parametrized
    # cases (previously an if/else with identical meaning).
    callback._tqdm_cls.assert_called_with(total=10, **tqdm_kwargs)
|