153 lines
5.2 KiB
Python
153 lines
5.2 KiB
Python
![]() |
import numpy as np
|
||
|
import pytest
|
||
|
|
||
|
from pandas.errors import NumbaUtilError
|
||
|
import pandas.util._test_decorators as td
|
||
|
|
||
|
from pandas import DataFrame, NamedAgg, option_context
|
||
|
import pandas._testing as tm
|
||
|
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
|
||
|
|
||
|
|
||
|
@td.skip_if_no("numba", "0.46.0")
def test_correct_function_signature():
    """A one-argument UDF is rejected by the numba ``agg`` engine.

    Both the DataFrameGroupBy and the SeriesGroupBy ``agg`` paths must raise
    ``NumbaUtilError`` with a message matching "The first 2" when asked to run
    ``bad_signature(x)`` with ``engine="numba"``.
    """

    def bad_signature(x):
        return sum(x) * 2.7

    frame = DataFrame(
        {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
        columns=["key", "data"],
    )

    # Each target gets its own fresh groupby, as two independent calls would.
    targets = (frame.groupby("key"), frame.groupby("key")["data"])
    for target in targets:
        with pytest.raises(NumbaUtilError, match="The first 2"):
            target.agg(bad_signature, engine="numba")
|
||
|
|
||
|
|
||
|
@td.skip_if_no("numba", "0.46.0")
def test_check_nopython_kwargs():
    """Extra keyword arguments to a numba-engine ``agg`` UDF are rejected.

    Passing ``a=1`` through ``agg`` alongside ``engine="numba"`` must raise
    ``NumbaUtilError`` matching "numba does not support". This is checked for
    both the DataFrameGroupBy and SeriesGroupBy paths.
    """

    def udf_with_kwargs(x, **kwargs):
        return sum(x) * 2.7

    frame = DataFrame(
        {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
        columns=["key", "data"],
    )

    targets = (frame.groupby("key"), frame.groupby("key")["data"])
    for target in targets:
        with pytest.raises(NumbaUtilError, match="numba does not support"):
            target.agg(udf_with_kwargs, engine="numba", a=1)
|
||
|
|
||
|
|
||
|
@td.skip_if_no("numba", "0.46.0")
# Filter warnings when parallel=True and the function can't be parallelized by Numba
@pytest.mark.filterwarnings("ignore:\\nThe keyword argument")
@pytest.mark.parametrize("jit", [True, False])
@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython):
    """The numba ``agg`` engine must agree with the cython engine.

    The same "mean times 2.7" reduction is run once as a numba UDF (optionally
    pre-jitted) and once as a cython-engine lambda. The two results are then
    compared on either a SeriesGroupBy or a DataFrameGroupBy.

    ``nogil``, ``parallel`` and ``nopython`` are fixtures defined outside this
    file; presumably boolean switches forwarded as ``engine_kwargs`` — TODO
    confirm against conftest.
    """

    def scaled_mean(values, index):
        return np.mean(values) * 2.7

    if jit:
        # Already-jitted UDFs must be accepted, not just plain Python ones.
        import numba

        scaled_mean = numba.jit(scaled_mean)

    frame = DataFrame(
        {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
    )
    grouped = frame.groupby(0)
    if pandas_obj == "Series":
        grouped = grouped[1]

    numba_result = grouped.agg(
        scaled_mean,
        engine="numba",
        engine_kwargs={"nogil": nogil, "parallel": parallel, "nopython": nopython},
    )
    cython_result = grouped.agg(lambda x: np.mean(x) * 2.7, engine="cython")

    tm.assert_equal(numba_result, cython_result)
|
||
|
|
||
|
|
||
|
@td.skip_if_no("numba", "0.46.0")
@pytest.mark.filterwarnings("ignore:\\nThe keyword argument")
# Filter warnings when parallel=True and the function can't be parallelized by Numba
@pytest.mark.parametrize("jit", [True, False])
@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
def test_cache(jit, pandas_obj, nogil, parallel, nopython):
    """Compiled numba ``agg`` functions are cached per UDF and kept.

    The test runs ``func_1``, then ``func_2``, then ``func_1`` again through
    the numba engine. Each result is checked against the cython engine, and
    ``NUMBA_FUNC_CACHE`` is checked to hold a ``(func, "groupby_agg")`` entry
    for each UDF. The entry stored for ``func_1`` must stay the same object
    throughout: neither evicted by ``func_2`` nor replaced by the re-run.

    ``nogil``, ``parallel`` and ``nopython`` are fixtures defined outside this
    file; presumably boolean switches forwarded as ``engine_kwargs`` — TODO
    confirm against conftest.
    """
    # Test that the functions are cached correctly if we switch functions
    def func_1(values, index):
        return np.mean(values) - 3.4

    def func_2(values, index):
        return np.mean(values) * 2.7

    if jit:
        import numba

        func_1 = numba.jit(func_1)
        func_2 = numba.jit(func_2)

    data = DataFrame(
        {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
    )
    engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
    grouped = data.groupby(0)
    if pandas_obj == "Series":
        grouped = grouped[1]

    func_1_key = (func_1, "groupby_agg")
    func_2_key = (func_2, "groupby_agg")

    result = grouped.agg(func_1, engine="numba", engine_kwargs=engine_kwargs)
    expected = grouped.agg(lambda x: np.mean(x) - 3.4, engine="cython")
    tm.assert_equal(result, expected)
    # func_1 should be in the cache now
    assert func_1_key in NUMBA_FUNC_CACHE
    func_1_compiled = NUMBA_FUNC_CACHE[func_1_key]

    # Add func_2 to the cache
    result = grouped.agg(func_2, engine="numba", engine_kwargs=engine_kwargs)
    expected = grouped.agg(lambda x: np.mean(x) * 2.7, engine="cython")
    tm.assert_equal(result, expected)
    assert func_2_key in NUMBA_FUNC_CACHE
    # Switching UDFs must not evict or overwrite func_1's compiled entry.
    assert NUMBA_FUNC_CACHE[func_1_key] is func_1_compiled

    # Retest func_1 which should use the cache
    result = grouped.agg(func_1, engine="numba", engine_kwargs=engine_kwargs)
    expected = grouped.agg(lambda x: np.mean(x) - 3.4, engine="cython")
    tm.assert_equal(result, expected)
    # The re-run must leave the originally stored compiled function in place,
    # and both UDFs must still be cached.
    assert NUMBA_FUNC_CACHE[func_1_key] is func_1_compiled
    assert func_2_key in NUMBA_FUNC_CACHE
|
||
|
|
||
|
|
||
|
@td.skip_if_no("numba", "0.46.0")
def test_use_global_config():
    """``engine=None`` follows the ``compute.use_numba`` option.

    With the option enabled, ``agg(func_1, engine=None)`` must produce the
    same frame as an explicit ``agg(func_1, engine="numba")`` call.
    """

    def func_1(values, index):
        return np.mean(values) - 3.4

    data = DataFrame(
        {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
    )
    grouped = data.groupby(0)
    expected = grouped.agg(func_1, engine="numba")
    with option_context("compute.use_numba", True):
        result = grouped.agg(func_1, engine=None)
    # (result, expected) order, matching the other assertions in this module,
    # so a failure report labels the actual and expected frames correctly.
    tm.assert_frame_equal(result, expected)
|
||
|
|
||
|
|
||
|
@td.skip_if_no("numba", "0.46.0")
@pytest.mark.parametrize(
    "agg_func",
    [
        ["min", "max"],
        "min",
        {"B": ["min", "max"], "C": "sum"},
        NamedAgg(column="B", aggfunc="min"),
    ],
)
def test_multifunc_notimplimented(agg_func):
    """String, list, dict and NamedAgg ``agg`` specs are not numba-capable.

    For each spec, both ``grouped.agg`` and ``grouped[1].agg`` must raise
    ``NotImplementedError`` matching "Numba engine can" under
    ``engine="numba"``.

    NOTE(review): the dict/NamedAgg specs name columns "B"/"C", which do not
    exist in the frame (columns are 0 and 1); the test relies on the
    NotImplementedError being raised before any column lookup.
    The function name's spelling is kept so existing test IDs stay stable.
    """
    frame = DataFrame(
        {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
    )
    by_first_col = frame.groupby(0)

    for target in (by_first_col, by_first_col[1]):
        with pytest.raises(NotImplementedError, match="Numba engine can"):
            target.agg(agg_func, engine="numba")
|