3RNN/Lib/site-packages/pandas/core/groupby/numba_.py
2024-05-26 19:49:15 +02:00

182 lines
4.8 KiB
Python

"""Common utilities for Numba operations with groupby ops"""
from __future__ import annotations
import functools
import inspect
from typing import (
TYPE_CHECKING,
Any,
Callable,
)
import numpy as np
from pandas.compat._optional import import_optional_dependency
from pandas.core.util.numba_ import (
NumbaUtilError,
jit_user_function,
)
if TYPE_CHECKING:
from pandas._typing import Scalar
def validate_udf(func: Callable) -> None:
    """
    Validate user defined function for ops when using Numba with groupby ops.

    The first signature arguments should include:

    def f(values, index, ...):
        ...

    Parameters
    ----------
    func : callable
        user defined function

    Returns
    -------
    None

    Raises
    ------
    NotImplementedError
        If ``func`` is not callable.
    NumbaUtilError
        If the first two parameters of ``func`` are not named
        ``values`` and ``index`` (in that order).
    """
    if not callable(func):
        raise NotImplementedError(
            "Numba engine can only be used with a single function."
        )
    udf_signature = list(inspect.signature(func).parameters.keys())
    expected_args = ["values", "index"]
    min_number_args = len(expected_args)
    if (
        len(udf_signature) < min_number_args
        or udf_signature[:min_number_args] != expected_args
    ):
        # Not every callable carries ``__name__`` (e.g. functools.partial
        # objects or instances defining ``__call__``); fall back to repr so
        # the intended NumbaUtilError is raised rather than AttributeError.
        func_name = getattr(func, "__name__", repr(func))
        raise NumbaUtilError(
            f"The first {min_number_args} arguments to {func_name} must be "
            f"{expected_args}"
        )
@functools.cache
def generate_numba_agg_func(
    func: Callable[..., Scalar],
    nopython: bool,
    nogil: bool,
    parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Build a numba-jitted groupby aggregation routine around ``func``.

    The user's function is first made jittable, then embedded in a jitted
    loop that evaluates it once per (group, column) pair. The options taken
    from engine_kwargs are applied to the groupby evaluation loop. The
    generated routine is memoized per argument combination by
    ``functools.cache``.

    Parameters
    ----------
    func : function
        function to be applied to each group and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        Called as ``group_agg(values, index, begin, end, num_columns, *args)``;
        returns an array of shape ``(len(begin), num_columns)``.
    """
    if not TYPE_CHECKING:
        numba = import_optional_dependency("numba")
    else:
        import numba

    jitted_udf = jit_user_function(func)
    jit_options = {"nopython": nopython, "nogil": nogil, "parallel": parallel}

    @numba.jit(**jit_options)
    def group_agg(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:
        # begin[k] / end[k] delimit the rows of group k, so both must align.
        assert len(begin) == len(end)
        n_groups = len(begin)
        out = np.empty((n_groups, num_columns))
        for grp in numba.prange(n_groups):
            lo = begin[grp]
            hi = end[grp]
            # The index slice is shared by every column of this group.
            labels = index[lo:hi]
            for col in numba.prange(num_columns):
                out[grp, col] = jitted_udf(values[lo:hi, col], labels, *args)
        return out

    return group_agg
@functools.cache
def generate_numba_transform_func(
    func: Callable[..., np.ndarray],
    nopython: bool,
    nogil: bool,
    parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Build a numba-jitted groupby transform routine around ``func``.

    The user's function is first made jittable, then embedded in a jitted
    loop that evaluates it once per (group, column) pair and writes the
    returned values back into the rows belonging to that group. The options
    taken from engine_kwargs are applied to the groupby evaluation loop.
    The generated routine is memoized per argument combination by
    ``functools.cache``.

    Parameters
    ----------
    func : function
        function to be applied to each group and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        Called as
        ``group_transform(values, index, begin, end, num_columns, *args)``;
        returns an array of shape ``(len(values), num_columns)``.
    """
    if not TYPE_CHECKING:
        numba = import_optional_dependency("numba")
    else:
        import numba

    jitted_udf = jit_user_function(func)
    jit_options = {"nopython": nopython, "nogil": nogil, "parallel": parallel}

    @numba.jit(**jit_options)
    def group_transform(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:
        # begin[k] / end[k] delimit the rows of group k, so both must align.
        assert len(begin) == len(end)
        n_groups = len(begin)
        out = np.empty((len(values), num_columns))
        for grp in numba.prange(n_groups):
            lo = begin[grp]
            hi = end[grp]
            # The index slice is shared by every column of this group.
            labels = index[lo:hi]
            for col in numba.prange(num_columns):
                # Transformed values land in the same rows they were read from.
                out[lo:hi, col] = jitted_udf(values[lo:hi, col], labels, *args)
        return out

    return group_transform