179 lines
5.1 KiB
Python
179 lines
5.1 KiB
Python
"""Common utilities for Numba operations with groupby ops"""
|
|
import inspect
|
|
from typing import Any, Callable, Dict, Optional, Tuple
|
|
|
|
import numpy as np
|
|
|
|
from pandas._typing import Scalar
|
|
from pandas.compat._optional import import_optional_dependency
|
|
|
|
from pandas.core.util.numba_ import (
|
|
NUMBA_FUNC_CACHE,
|
|
NumbaUtilError,
|
|
get_jit_arguments,
|
|
jit_user_function,
|
|
)
|
|
|
|
|
|
def validate_udf(func: Callable) -> None:
    """
    Validate user defined function for ops when using Numba with groupby ops.

    The first signature arguments should include:

    def f(values, index, ...):
        ...

    Parameters
    ----------
    func : function
        user defined function

    Returns
    -------
    None

    Raises
    ------
    NumbaUtilError
        If the first two parameters of ``func`` are not named ``values``
        and ``index`` (in that order), including when ``func`` accepts
        fewer than two parameters.
    """
    udf_signature = list(inspect.signature(func).parameters.keys())
    expected_args = ["values", "index"]
    min_number_args = len(expected_args)
    # A slice shorter than ``expected_args`` can never compare equal to it,
    # so this one comparison also rejects functions with too few parameters.
    if udf_signature[:min_number_args] != expected_args:
        # Callables such as functools.partial objects have no ``__name__``;
        # fall back to their repr so the documented NumbaUtilError is raised
        # instead of an AttributeError while building the message.
        func_name = getattr(func, "__name__", repr(func))
        raise NumbaUtilError(
            f"The first {min_number_args} arguments to {func_name} must be "
            f"{expected_args}"
        )
|
|
|
|
|
|
def generate_numba_agg_func(
    args: Tuple,
    kwargs: Dict[str, Any],
    func: Callable[..., Scalar],
    engine_kwargs: Optional[Dict[str, bool]],
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]:
    """
    Build a numba-jitted groupby aggregation kernel around a user function.

    The user's function is jitted first and then called from a jitted loop
    over every (group, column) slice. The ``nopython``/``nogil``/``parallel``
    settings derived from ``engine_kwargs`` are applied to both the user's
    function and that loop.

    Parameters
    ----------
    args : tuple
        Extra positional arguments passed to ``func`` after the group's
        values and index.
    kwargs : dict
        Keyword arguments; they are only handed to ``get_jit_arguments``
        here and are not forwarded to ``func`` by the generated kernel.
    func : function
        User function applied to each group; it will be JIT-compiled.
    engine_kwargs : dict
        Dictionary of arguments to be passed into numba.jit.

    Returns
    -------
    Numba function
        ``group_agg(values, index, begin, end, num_groups, num_columns)``,
        which returns an array of shape ``(num_groups, num_columns)``. If a
        kernel is already present in ``NUMBA_FUNC_CACHE`` under
        ``(func, "groupby_agg")`` that object is returned instead.
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)

    validate_udf(func)

    # The cache is only read here; nothing in this function stores into it.
    # NOTE(review): the key ignores ``args`` although the kernel closes over
    # them -- confirm callers never reuse a cached kernel with different args.
    cache_key = (func, "groupby_agg")
    if cache_key in NUMBA_FUNC_CACHE:
        return NUMBA_FUNC_CACHE[cache_key]

    jitted_udf = jit_user_function(func, nopython, nogil, parallel)
    numba = import_optional_dependency("numba")
    # Use numba's parallel range only when parallel compilation was requested.
    iterate = numba.prange if parallel else range

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def group_agg(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_groups: int,
        num_columns: int,
    ) -> np.ndarray:
        out = np.empty((num_groups, num_columns))
        for grp in iterate(num_groups):
            lo = begin[grp]
            hi = end[grp]
            grp_index = index[lo:hi]
            for col in iterate(num_columns):
                out[grp, col] = jitted_udf(values[lo:hi, col], grp_index, *args)
        return out

    return group_agg
|
|
|
|
|
|
def generate_numba_transform_func(
    args: Tuple,
    kwargs: Dict[str, Any],
    func: Callable[..., np.ndarray],
    engine_kwargs: Optional[Dict[str, bool]],
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]:
    """
    Generate a numba jitted transform function specified by values from engine_kwargs.

    1. jit the user's function
    2. Return a groupby transform function with the jitted function inline

    Configurations specified in engine_kwargs apply to both the user's
    function _AND_ the groupby evaluation loop.

    Parameters
    ----------
    args : tuple
        *args to be passed into the function
    kwargs : dict
        **kwargs; only handed to ``get_jit_arguments`` here, they are not
        forwarded to ``func`` by the generated loop
    func : function
        function to be applied to each group and will be JITed
    engine_kwargs : dict
        dictionary of arguments to be passed into numba.jit

    Returns
    -------
    Numba function
        ``group_transform(values, index, begin, end, num_groups, num_columns)``
        returning a float array of shape ``(len(values), num_columns)``.
        Rows inside a slice ``begin[i]:end[i]`` hold the user function's
        output for that group; rows covered by no slice are NaN. If a
        function is already present in ``NUMBA_FUNC_CACHE`` under
        ``(func, "groupby_transform")`` that object is returned instead.

    Raises
    ------
    NumbaUtilError
        From ``validate_udf`` if ``func``'s signature does not start with
        ``(values, index)``.
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)

    validate_udf(func)
    # The cache is only read here; nothing in this function stores into it.
    # NOTE(review): the key ignores ``args`` although the loop closes over
    # them -- confirm callers never reuse a cached function with different args.
    cache_key = (func, "groupby_transform")
    if cache_key in NUMBA_FUNC_CACHE:
        return NUMBA_FUNC_CACHE[cache_key]

    numba_func = jit_user_function(func, nopython, nogil, parallel)
    numba = import_optional_dependency("numba")
    # Use numba's parallel range only when parallel compilation was requested.
    loop_range = numba.prange if parallel else range

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def group_transform(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_groups: int,
        num_columns: int,
    ) -> np.ndarray:
        # Pre-fill with NaN instead of np.empty: any row that falls outside
        # every [begin[i], end[i]) slice would otherwise be returned as
        # uninitialized memory. (Presumably such rows are ones excluded from
        # grouping by the caller -- confirm against the call site.)
        result = np.full((len(values), num_columns), np.nan)
        for i in loop_range(num_groups):
            group_index = index[begin[i] : end[i]]
            for j in loop_range(num_columns):
                group = values[begin[i] : end[i], j]
                result[begin[i] : end[i], j] = numba_func(group, group_index, *args)
        return result

    return group_transform
|