112 lines
3.0 KiB
Python
112 lines
3.0 KiB
Python
"""Common utilities for Numba operations"""
|
|
from distutils.version import LooseVersion
|
|
import types
|
|
from typing import Callable, Dict, Optional, Tuple
|
|
|
|
import numpy as np
|
|
|
|
from pandas.compat._optional import import_optional_dependency
|
|
from pandas.errors import NumbaUtilError
|
|
|
|
# Module-wide default consulted by maybe_use_numba when no engine is given
# (engine is None); rebound by set_use_numba.
GLOBAL_USE_NUMBA: bool = False
|
|
# Maps a (callable, str) key to a callable, per the annotation. It is not read
# or written in this part of the file -- presumably a cache of compiled
# functions; confirm against callers.
NUMBA_FUNC_CACHE: Dict[Tuple[Callable, str], Callable] = {}
|
|
|
|
|
|
def maybe_use_numba(engine: Optional[str]) -> bool:
    """
    Decide whether the numba code path should be taken.

    Parameters
    ----------
    engine : str or None
        Engine requested by the caller.

    Returns
    -------
    bool
        True when ``engine`` is ``"numba"``. When ``engine`` is None the
        module-level ``GLOBAL_USE_NUMBA`` flag is returned. Any other
        engine yields False.
    """
    if engine is None:
        # No explicit choice was made: defer to the module-wide default.
        return GLOBAL_USE_NUMBA
    return engine == "numba"
|
|
|
|
|
|
def set_use_numba(enable: bool = False) -> None:
    """
    Set the module-wide default that decides whether numba is used.

    Parameters
    ----------
    enable : bool, default False
        New value for ``GLOBAL_USE_NUMBA``. A truthy value first triggers
        ``import_optional_dependency("numba")``.
    """
    global GLOBAL_USE_NUMBA

    if not enable:
        GLOBAL_USE_NUMBA = enable
        return

    # The dependency check runs before the assignment, so if it raises the
    # flag keeps its previous value.
    import_optional_dependency("numba")
    GLOBAL_USE_NUMBA = enable
|
|
|
|
|
|
def get_jit_arguments(
    engine_kwargs: Optional[Dict[str, bool]] = None, kwargs: Optional[Dict] = None
) -> Tuple[bool, bool, bool]:
    """
    Resolve the numba.JIT settings, filling gaps with the pandas defaults.

    Parameters
    ----------
    engine_kwargs : dict, default None
        user passed keyword arguments for numba.JIT; missing keys default to
        ``nopython=True``, ``nogil=False`` and ``parallel=False``
    kwargs : dict, default None
        user passed keyword arguments to pass into the JITed function

    Returns
    -------
    (bool, bool, bool)
        nopython, nogil, parallel

    Raises
    ------
    NumbaUtilError
        If ``kwargs`` is non-empty while ``nopython`` resolves to a truthy
        value.
    """
    jit_options = engine_kwargs if engine_kwargs is not None else {}

    nopython = jit_options.get("nopython", True)
    if nopython and kwargs:
        # Reject early: the combination is unsupported (see the linked issue).
        raise NumbaUtilError(
            "numba does not support kwargs with nopython=True: "
            "https://github.com/numba/numba/issues/2916"
        )

    return (
        nopython,
        jit_options.get("nogil", False),
        jit_options.get("parallel", False),
    )
|
|
|
|
|
|
def jit_user_function(
    func: Callable, nopython: bool, nogil: bool, parallel: bool
) -> Callable:
    """
    JIT the user's function given the configurable arguments.

    Parameters
    ----------
    func : function
        user defined function
    nopython : bool
        nopython parameter for numba.JIT
    nogil : bool
        nogil parameter for numba.JIT
    parallel : bool
        parallel parameter for numba.JIT

    Returns
    -------
    function
        ``func`` itself when it is already numba-jitted, otherwise a numba
        ``generated_jit`` wrapper that dispatches to ``func``.
    """
    numba = import_optional_dependency("numba")

    # Detect the helper by feature rather than by comparing version strings:
    # this drops the dependency on the deprecated distutils LooseVersion and
    # works with any numba release that exposes ``is_jitted``.
    is_jitted_check = getattr(numba.extending, "is_jitted", None)
    if is_jitted_check is not None:
        already_jitted = is_jitted_check(func)
    else:
        # Legacy numba layout without ``numba.extending.is_jitted``.
        already_jitted = isinstance(func, numba.targets.registry.CPUDispatcher)

    if already_jitted:
        # Don't jit a user passed jitted function
        return func

    # NOTE(review): generated_jit appears to be deprecated in recent numba
    # releases -- confirm the supported numba version range.
    @numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def numba_func(data, *_args):
        # numpy functions and builtins are used as-is; anything else is
        # jitted with the requested nopython/nogil settings.
        if getattr(np, func.__name__, False) is func or isinstance(
            func, types.BuiltinFunctionType
        ):
            jf = func
        else:
            jf = numba.jit(func, nopython=nopython, nogil=nogil)

        def impl(data, *_args):
            return jf(data, *_args)

        return impl

    return numba_func
|