import pytest from pandas.errors import NumbaUtilError import pandas.util._test_decorators as td from pandas import ( DataFrame, Series, option_context, ) import pandas._testing as tm @td.skip_if_no("numba") def test_correct_function_signature(): def incorrect_function(x): return x + 1 data = DataFrame( {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=["key", "data"], ) with pytest.raises(NumbaUtilError, match="The first 2"): data.groupby("key").transform(incorrect_function, engine="numba") with pytest.raises(NumbaUtilError, match="The first 2"): data.groupby("key")["data"].transform(incorrect_function, engine="numba") @td.skip_if_no("numba") def test_check_nopython_kwargs(): def incorrect_function(values, index): return values + 1 data = DataFrame( {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=["key", "data"], ) with pytest.raises(NumbaUtilError, match="numba does not support"): data.groupby("key").transform(incorrect_function, engine="numba", a=1) with pytest.raises(NumbaUtilError, match="numba does not support"): data.groupby("key")["data"].transform(incorrect_function, engine="numba", a=1) @td.skip_if_no("numba") @pytest.mark.filterwarnings("ignore") # Filter warnings when parallel=True and the function can't be parallelized by Numba @pytest.mark.parametrize("jit", [True, False]) @pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"]) @pytest.mark.parametrize("as_index", [True, False]) def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython, as_index): def func(values, index): return values + 1 if jit: # Test accepted jitted functions import numba func = numba.jit(func) data = DataFrame( {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1] ) engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} grouped = data.groupby(0, as_index=as_index) if pandas_obj == "Series": grouped = grouped[1] result = grouped.transform(func, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x + 1, engine="cython") tm.assert_equal(result, expected) @td.skip_if_no("numba") @pytest.mark.filterwarnings("ignore") # Filter warnings when parallel=True and the function can't be parallelized by Numba @pytest.mark.parametrize("jit", [True, False]) @pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"]) def test_cache(jit, pandas_obj, nogil, parallel, nopython): # Test that the functions are cached correctly if we switch functions def func_1(values, index): return values + 1 def func_2(values, index): return values * 5 if jit: import numba func_1 = numba.jit(func_1) func_2 = numba.jit(func_2) data = DataFrame( {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1] ) engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} grouped = data.groupby(0) if pandas_obj == "Series": grouped = grouped[1] result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x + 1, engine="cython") tm.assert_equal(result, expected) result = grouped.transform(func_2, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x * 5, engine="cython") tm.assert_equal(result, expected) # Retest func_1 which should use the cache result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x + 1, engine="cython") tm.assert_equal(result, expected) @td.skip_if_no("numba") def test_use_global_config(): def func_1(values, index): return values + 1 data = DataFrame( {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1] ) grouped = data.groupby(0) expected = grouped.transform(func_1, engine="numba") with option_context("compute.use_numba", True): result = grouped.transform(func_1, engine=None) tm.assert_frame_equal(expected, result) @td.skip_if_no("numba") @pytest.mark.parametrize( "agg_func", [["min", "max"], "min", {"B": ["min", "max"], "C": "sum"}] ) def test_multifunc_notimplimented(agg_func): data = DataFrame( {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1] ) grouped = data.groupby(0) with pytest.raises(NotImplementedError, match="Numba engine can"): grouped.transform(agg_func, engine="numba") with pytest.raises(NotImplementedError, match="Numba engine can"): grouped[1].transform(agg_func, engine="numba") @td.skip_if_no("numba") def test_args_not_cached(): # GH 41647 def sum_last(values, index, n): return values[-n:].sum() df = DataFrame({"id": [0, 0, 1, 1], "x": [1, 1, 1, 1]}) grouped_x = df.groupby("id")["x"] result = grouped_x.transform(sum_last, 1, engine="numba") expected = Series([1.0] * 4, name="x") tm.assert_series_equal(result, expected) result = grouped_x.transform(sum_last, 2, engine="numba") expected = Series([2.0] * 4, name="x") tm.assert_series_equal(result, expected) @td.skip_if_no("numba") def test_index_data_correctly_passed(): # GH 43133 def f(values, index): return index - 1 df = DataFrame({"group": ["A", "A", "B"], "v": [4, 5, 6]}, index=[-1, -2, -3]) result = df.groupby("group").transform(f, engine="numba") expected = DataFrame([-4.0, -3.0, -2.0], columns=["v"], index=[-1, -2, -3]) tm.assert_frame_equal(result, expected) @td.skip_if_no("numba") def test_engine_kwargs_not_cached(): # If the user passes a different set of engine_kwargs don't return the same # jitted function nogil = True parallel = False nopython = True def func_kwargs(values, index): return nogil + parallel + nopython engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} df = DataFrame({"value": [0, 0, 0]}) result = df.groupby(level=0).transform( func_kwargs, engine="numba", engine_kwargs=engine_kwargs ) expected = DataFrame({"value": [2.0, 2.0, 2.0]}) tm.assert_frame_equal(result, expected) nogil = False engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} result = df.groupby(level=0).transform( func_kwargs, engine="numba", engine_kwargs=engine_kwargs ) expected = DataFrame({"value": [1.0, 1.0, 1.0]}) tm.assert_frame_equal(result, expected) @td.skip_if_no("numba") @pytest.mark.filterwarnings("ignore") def test_multiindex_one_key(nogil, parallel, nopython): def numba_func(values, index): return 1 df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"]) engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} result = df.groupby("A").transform( numba_func, engine="numba", engine_kwargs=engine_kwargs ) expected = DataFrame([{"A": 1, "B": 2, "C": 1.0}]).set_index(["A", "B"]) tm.assert_frame_equal(result, expected) @td.skip_if_no("numba") def test_multiindex_multi_key_not_supported(nogil, parallel, nopython): def numba_func(values, index): return 1 df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"]) engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} with pytest.raises(NotImplementedError, match="More than 1 grouping labels"): df.groupby(["A", "B"]).transform( numba_func, engine="numba", engine_kwargs=engine_kwargs )