164 lines
4.7 KiB
Python
164 lines
4.7 KiB
Python
from typing import Any, Callable, Dict, Optional, Tuple
|
|
|
|
import numpy as np
|
|
|
|
from pandas._typing import Scalar
|
|
from pandas.compat._optional import import_optional_dependency
|
|
|
|
from pandas.core.util.numba_ import (
|
|
NUMBA_FUNC_CACHE,
|
|
get_jit_arguments,
|
|
jit_user_function,
|
|
)
|
|
|
|
|
|
def generate_numba_apply_func(
    args: Tuple,
    kwargs: Dict[str, Any],
    func: Callable[..., Scalar],
    engine_kwargs: Optional[Dict[str, bool]],
):
    """
    Build a numba-compiled rolling ``apply`` kernel around a user function.

    ``func`` is JIT-compiled first, and the compiled version is then called
    from a second JIT-compiled kernel that iterates over every window. The
    ``nopython``/``nogil``/``parallel`` settings obtained from
    ``engine_kwargs`` are used for both compilations.

    Parameters
    ----------
    args : tuple
        Extra positional arguments passed to ``func`` after the window.
        They are captured in the kernel's closure when it is generated.
    kwargs : dict
        Keyword arguments supplied alongside ``func``. They are only handed to
        ``get_jit_arguments`` together with ``engine_kwargs``; the generated
        kernel does not forward them to ``func``.
    func : function
        Function applied to each window; it is JIT-compiled.
    engine_kwargs : dict, optional
        Options used to derive the arguments for ``numba.jit``.

    Returns
    -------
    Numba function
        A kernel ``roll_apply(values, begin, end, minimum_periods)`` returning
        a float array with one entry per window (NaN where the window has
        fewer than ``minimum_periods`` non-NaN values), or whatever is already
        stored in ``NUMBA_FUNC_CACHE`` under ``(func, "rolling_apply")``.

    Notes
    -----
    This function reads ``NUMBA_FUNC_CACHE`` but never writes to it.
    NOTE(review): the cache key contains only ``func``, while ``args`` and the
    jit options are baked into the compiled kernel, so whatever was cached for
    ``func`` is returned regardless of the current ``args``/``engine_kwargs``
    -- confirm callers account for this.
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)

    key = (func, "rolling_apply")
    if key in NUMBA_FUNC_CACHE:
        return NUMBA_FUNC_CACHE[key]

    compiled_func = jit_user_function(func, nopython, nogil, parallel)
    numba = import_optional_dependency("numba")
    # numba.prange marks the window loop for parallel execution when
    # ``parallel`` is requested; plain ``range`` keeps it serial.
    window_range = numba.prange if parallel else range

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_apply(
        values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int
    ) -> np.ndarray:
        n_windows = len(begin)
        out = np.empty(n_windows)
        for w in window_range(n_windows):
            window = values[begin[w] : end[w]]
            # Only non-NaN entries count towards ``minimum_periods``.
            n_valid = len(window) - np.sum(np.isnan(window))
            if n_valid < minimum_periods:
                out[w] = np.nan
            else:
                out[w] = compiled_func(window, *args)
        return out

    return roll_apply
|
|
|
|
|
|
def generate_numba_groupby_ewma_func(
    engine_kwargs: Optional[Dict[str, bool]],
    com: float,
    adjust: bool,
    ignore_na: bool,
):
    """
    Generate a numba jitted groupby ewma function specified by values
    from engine_kwargs.

    The returned kernel computes an exponentially weighted moving average
    independently over each group ``values[begin[i]:end[i]]`` and writes the
    per-element results into the same positions of its output array.

    Parameters
    ----------
    engine_kwargs : dict, optional
        dictionary of arguments to be passed into numba.jit
    com : float
        Center of mass; the smoothing factor is ``alpha = 1 / (1 + com)``.
    adjust : bool
        If True, each new observation gets weight 1 and the decayed weight of
        past observations keeps accumulating. If False, new observations get
        weight ``alpha`` and the past weight is reset to 1 after each
        observation.
    ignore_na : bool
        If True, NaN entries do not decay the weight of past observations;
        if False, past weights also decay across NaN entries.

    Returns
    -------
    Numba function
        ``groupby_ewma(values, begin, end, minimum_periods) -> np.ndarray``.
        The output has ``len(values)`` entries; an entry is NaN while fewer
        than ``minimum_periods`` non-NaN values have been seen in its group.
        Positions not covered by any ``[begin[i], end[i])`` are left
        uninitialized (``np.empty``).
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs)

    # NOTE(review): ``lambda x: x`` is a new object on every call and
    # functions compare by identity, so this key can never match an existing
    # entry and the lookup always misses. A working key would also need to
    # include com/adjust/ignore_na and the jit options, since those are baked
    # into the compiled closure below.
    cache_key = (lambda x: x, "groupby_ewma")
    if cache_key in NUMBA_FUNC_CACHE:
        return NUMBA_FUNC_CACHE[cache_key]

    numba = import_optional_dependency("numba")
    # numba.prange marks the group loop for parallel execution when requested.
    loop_range = numba.prange if parallel else range

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def groupby_ewma(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        result = np.empty(len(values))
        alpha = 1.0 / (1.0 + com)
        # These depend only on closure parameters: compute once, not per group.
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha

        for i in loop_range(len(begin)):
            start = begin[i]
            stop = end[i]
            window = values[start:stop]
            # An empty group has nothing to write, and ``window[0]`` below
            # would index past the end of the zero-length slice.
            if len(window) > 0:
                # View into ``result``: writes land directly in the output.
                out = result[start:stop]

                weighted_avg = window[0]
                nobs = int(not np.isnan(weighted_avg))
                out[0] = weighted_avg if nobs >= minimum_periods else np.nan
                old_wt = 1.0

                for j in range(1, len(window)):
                    cur = window[j]
                    is_observation = not np.isnan(cur)
                    nobs += is_observation
                    if not np.isnan(weighted_avg):
                        if is_observation or not ignore_na:
                            # Decay the weight carried by past observations.
                            old_wt *= old_wt_factor
                            if is_observation:
                                # avoid numerical errors on constant series
                                if weighted_avg != cur:
                                    weighted_avg = (
                                        (old_wt * weighted_avg) + (new_wt * cur)
                                    ) / (old_wt + new_wt)
                                if adjust:
                                    old_wt += new_wt
                                else:
                                    old_wt = 1.0
                    elif is_observation:
                        # While the running average is NaN (e.g. only NaNs seen
                        # so far), the next observation seeds it.
                        weighted_avg = cur

                    out[j] = weighted_avg if nobs >= minimum_periods else np.nan

        return result

    return groupby_ewma
|