projektAI/venv/Lib/site-packages/pandas/core/util/numba_.py

112 lines
3.0 KiB
Python
Raw Normal View History

2021-06-06 22:13:05 +02:00
"""Common utilities for Numba operations"""
from distutils.version import LooseVersion
import types
from typing import Callable, Dict, Optional, Tuple
import numpy as np
from pandas.compat._optional import import_optional_dependency
from pandas.errors import NumbaUtilError
GLOBAL_USE_NUMBA: bool = False
NUMBA_FUNC_CACHE: Dict[Tuple[Callable, str], Callable] = {}
def maybe_use_numba(engine: Optional[str]) -> bool:
"""Signal whether to use numba routines."""
return engine == "numba" or (engine is None and GLOBAL_USE_NUMBA)
def set_use_numba(enable: bool = False) -> None:
global GLOBAL_USE_NUMBA
if enable:
import_optional_dependency("numba")
GLOBAL_USE_NUMBA = enable
def get_jit_arguments(
engine_kwargs: Optional[Dict[str, bool]] = None, kwargs: Optional[Dict] = None
) -> Tuple[bool, bool, bool]:
"""
Return arguments to pass to numba.JIT, falling back on pandas default JIT settings.
Parameters
----------
engine_kwargs : dict, default None
user passed keyword arguments for numba.JIT
kwargs : dict, default None
user passed keyword arguments to pass into the JITed function
Returns
-------
(bool, bool, bool)
nopython, nogil, parallel
Raises
------
NumbaUtilError
"""
if engine_kwargs is None:
engine_kwargs = {}
nopython = engine_kwargs.get("nopython", True)
if kwargs and nopython:
raise NumbaUtilError(
"numba does not support kwargs with nopython=True: "
"https://github.com/numba/numba/issues/2916"
)
nogil = engine_kwargs.get("nogil", False)
parallel = engine_kwargs.get("parallel", False)
return nopython, nogil, parallel
def jit_user_function(
func: Callable, nopython: bool, nogil: bool, parallel: bool
) -> Callable:
"""
JIT the user's function given the configurable arguments.
Parameters
----------
func : function
user defined function
nopython : bool
nopython parameter for numba.JIT
nogil : bool
nogil parameter for numba.JIT
parallel : bool
parallel parameter for numba.JIT
Returns
-------
function
Numba JITed function
"""
numba = import_optional_dependency("numba")
if LooseVersion(numba.__version__) >= LooseVersion("0.49.0"):
is_jitted = numba.extending.is_jitted(func)
else:
is_jitted = isinstance(func, numba.targets.registry.CPUDispatcher)
if is_jitted:
# Don't jit a user passed jitted function
numba_func = func
else:
@numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel)
def numba_func(data, *_args):
if getattr(np, func.__name__, False) is func or isinstance(
func, types.BuiltinFunctionType
):
jf = func
else:
jf = numba.jit(func, nopython=nopython, nogil=nogil)
def impl(data, *_args):
return jf(data, *_args)
return impl
return numba_func