352 lines
10 KiB
Python
352 lines
10 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import functools
|
||
|
from typing import (
|
||
|
TYPE_CHECKING,
|
||
|
Any,
|
||
|
Callable,
|
||
|
)
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from pandas.compat._optional import import_optional_dependency
|
||
|
|
||
|
from pandas.core.util.numba_ import jit_user_function
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from pandas._typing import Scalar
|
||
|
|
||
|
|
||
|
@functools.cache
|
||
|
def generate_numba_apply_func(
|
||
|
func: Callable[..., Scalar],
|
||
|
nopython: bool,
|
||
|
nogil: bool,
|
||
|
parallel: bool,
|
||
|
):
|
||
|
"""
|
||
|
Generate a numba jitted apply function specified by values from engine_kwargs.
|
||
|
|
||
|
1. jit the user's function
|
||
|
2. Return a rolling apply function with the jitted function inline
|
||
|
|
||
|
Configurations specified in engine_kwargs apply to both the user's
|
||
|
function _AND_ the rolling apply function.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
func : function
|
||
|
function to be applied to each window and will be JITed
|
||
|
nopython : bool
|
||
|
nopython to be passed into numba.jit
|
||
|
nogil : bool
|
||
|
nogil to be passed into numba.jit
|
||
|
parallel : bool
|
||
|
parallel to be passed into numba.jit
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
Numba function
|
||
|
"""
|
||
|
numba_func = jit_user_function(func)
|
||
|
if TYPE_CHECKING:
|
||
|
import numba
|
||
|
else:
|
||
|
numba = import_optional_dependency("numba")
|
||
|
|
||
|
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
|
||
|
def roll_apply(
|
||
|
values: np.ndarray,
|
||
|
begin: np.ndarray,
|
||
|
end: np.ndarray,
|
||
|
minimum_periods: int,
|
||
|
*args: Any,
|
||
|
) -> np.ndarray:
|
||
|
result = np.empty(len(begin))
|
||
|
for i in numba.prange(len(result)):
|
||
|
start = begin[i]
|
||
|
stop = end[i]
|
||
|
window = values[start:stop]
|
||
|
count_nan = np.sum(np.isnan(window))
|
||
|
if len(window) - count_nan >= minimum_periods:
|
||
|
result[i] = numba_func(window, *args)
|
||
|
else:
|
||
|
result[i] = np.nan
|
||
|
return result
|
||
|
|
||
|
return roll_apply
|
||
|
|
||
|
|
||
|
@functools.cache
|
||
|
def generate_numba_ewm_func(
|
||
|
nopython: bool,
|
||
|
nogil: bool,
|
||
|
parallel: bool,
|
||
|
com: float,
|
||
|
adjust: bool,
|
||
|
ignore_na: bool,
|
||
|
deltas: tuple,
|
||
|
normalize: bool,
|
||
|
):
|
||
|
"""
|
||
|
Generate a numba jitted ewm mean or sum function specified by values
|
||
|
from engine_kwargs.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
nopython : bool
|
||
|
nopython to be passed into numba.jit
|
||
|
nogil : bool
|
||
|
nogil to be passed into numba.jit
|
||
|
parallel : bool
|
||
|
parallel to be passed into numba.jit
|
||
|
com : float
|
||
|
adjust : bool
|
||
|
ignore_na : bool
|
||
|
deltas : tuple
|
||
|
normalize : bool
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
Numba function
|
||
|
"""
|
||
|
if TYPE_CHECKING:
|
||
|
import numba
|
||
|
else:
|
||
|
numba = import_optional_dependency("numba")
|
||
|
|
||
|
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
|
||
|
def ewm(
|
||
|
values: np.ndarray,
|
||
|
begin: np.ndarray,
|
||
|
end: np.ndarray,
|
||
|
minimum_periods: int,
|
||
|
) -> np.ndarray:
|
||
|
result = np.empty(len(values))
|
||
|
alpha = 1.0 / (1.0 + com)
|
||
|
old_wt_factor = 1.0 - alpha
|
||
|
new_wt = 1.0 if adjust else alpha
|
||
|
|
||
|
for i in numba.prange(len(begin)):
|
||
|
start = begin[i]
|
||
|
stop = end[i]
|
||
|
window = values[start:stop]
|
||
|
sub_result = np.empty(len(window))
|
||
|
|
||
|
weighted = window[0]
|
||
|
nobs = int(not np.isnan(weighted))
|
||
|
sub_result[0] = weighted if nobs >= minimum_periods else np.nan
|
||
|
old_wt = 1.0
|
||
|
|
||
|
for j in range(1, len(window)):
|
||
|
cur = window[j]
|
||
|
is_observation = not np.isnan(cur)
|
||
|
nobs += is_observation
|
||
|
if not np.isnan(weighted):
|
||
|
if is_observation or not ignore_na:
|
||
|
if normalize:
|
||
|
# note that len(deltas) = len(vals) - 1 and deltas[i]
|
||
|
# is to be used in conjunction with vals[i+1]
|
||
|
old_wt *= old_wt_factor ** deltas[start + j - 1]
|
||
|
else:
|
||
|
weighted = old_wt_factor * weighted
|
||
|
if is_observation:
|
||
|
if normalize:
|
||
|
# avoid numerical errors on constant series
|
||
|
if weighted != cur:
|
||
|
weighted = old_wt * weighted + new_wt * cur
|
||
|
if normalize:
|
||
|
weighted = weighted / (old_wt + new_wt)
|
||
|
if adjust:
|
||
|
old_wt += new_wt
|
||
|
else:
|
||
|
old_wt = 1.0
|
||
|
else:
|
||
|
weighted += cur
|
||
|
elif is_observation:
|
||
|
weighted = cur
|
||
|
|
||
|
sub_result[j] = weighted if nobs >= minimum_periods else np.nan
|
||
|
|
||
|
result[start:stop] = sub_result
|
||
|
|
||
|
return result
|
||
|
|
||
|
return ewm
|
||
|
|
||
|
|
||
|
@functools.cache
|
||
|
def generate_numba_table_func(
|
||
|
func: Callable[..., np.ndarray],
|
||
|
nopython: bool,
|
||
|
nogil: bool,
|
||
|
parallel: bool,
|
||
|
):
|
||
|
"""
|
||
|
Generate a numba jitted function to apply window calculations table-wise.
|
||
|
|
||
|
Func will be passed a M window size x N number of columns array, and
|
||
|
must return a 1 x N number of columns array. Func is intended to operate
|
||
|
row-wise, but the result will be transposed for axis=1.
|
||
|
|
||
|
1. jit the user's function
|
||
|
2. Return a rolling apply function with the jitted function inline
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
func : function
|
||
|
function to be applied to each window and will be JITed
|
||
|
nopython : bool
|
||
|
nopython to be passed into numba.jit
|
||
|
nogil : bool
|
||
|
nogil to be passed into numba.jit
|
||
|
parallel : bool
|
||
|
parallel to be passed into numba.jit
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
Numba function
|
||
|
"""
|
||
|
numba_func = jit_user_function(func)
|
||
|
if TYPE_CHECKING:
|
||
|
import numba
|
||
|
else:
|
||
|
numba = import_optional_dependency("numba")
|
||
|
|
||
|
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
|
||
|
def roll_table(
|
||
|
values: np.ndarray,
|
||
|
begin: np.ndarray,
|
||
|
end: np.ndarray,
|
||
|
minimum_periods: int,
|
||
|
*args: Any,
|
||
|
):
|
||
|
result = np.empty((len(begin), values.shape[1]))
|
||
|
min_periods_mask = np.empty(result.shape)
|
||
|
for i in numba.prange(len(result)):
|
||
|
start = begin[i]
|
||
|
stop = end[i]
|
||
|
window = values[start:stop]
|
||
|
count_nan = np.sum(np.isnan(window), axis=0)
|
||
|
sub_result = numba_func(window, *args)
|
||
|
nan_mask = len(window) - count_nan >= minimum_periods
|
||
|
min_periods_mask[i, :] = nan_mask
|
||
|
result[i, :] = sub_result
|
||
|
result = np.where(min_periods_mask, result, np.nan)
|
||
|
return result
|
||
|
|
||
|
return roll_table
|
||
|
|
||
|
|
||
|
# This function will no longer be needed once numba supports
|
||
|
# axis for all np.nan* agg functions
|
||
|
# https://github.com/numba/numba/issues/1269
|
||
|
@functools.cache
|
||
|
def generate_manual_numpy_nan_agg_with_axis(nan_func):
|
||
|
if TYPE_CHECKING:
|
||
|
import numba
|
||
|
else:
|
||
|
numba = import_optional_dependency("numba")
|
||
|
|
||
|
@numba.jit(nopython=True, nogil=True, parallel=True)
|
||
|
def nan_agg_with_axis(table):
|
||
|
result = np.empty(table.shape[1])
|
||
|
for i in numba.prange(table.shape[1]):
|
||
|
partition = table[:, i]
|
||
|
result[i] = nan_func(partition)
|
||
|
return result
|
||
|
|
||
|
return nan_agg_with_axis
|
||
|
|
||
|
|
||
|
@functools.cache
|
||
|
def generate_numba_ewm_table_func(
|
||
|
nopython: bool,
|
||
|
nogil: bool,
|
||
|
parallel: bool,
|
||
|
com: float,
|
||
|
adjust: bool,
|
||
|
ignore_na: bool,
|
||
|
deltas: tuple,
|
||
|
normalize: bool,
|
||
|
):
|
||
|
"""
|
||
|
Generate a numba jitted ewm mean or sum function applied table wise specified
|
||
|
by values from engine_kwargs.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
nopython : bool
|
||
|
nopython to be passed into numba.jit
|
||
|
nogil : bool
|
||
|
nogil to be passed into numba.jit
|
||
|
parallel : bool
|
||
|
parallel to be passed into numba.jit
|
||
|
com : float
|
||
|
adjust : bool
|
||
|
ignore_na : bool
|
||
|
deltas : tuple
|
||
|
normalize: bool
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
Numba function
|
||
|
"""
|
||
|
if TYPE_CHECKING:
|
||
|
import numba
|
||
|
else:
|
||
|
numba = import_optional_dependency("numba")
|
||
|
|
||
|
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
|
||
|
def ewm_table(
|
||
|
values: np.ndarray,
|
||
|
begin: np.ndarray,
|
||
|
end: np.ndarray,
|
||
|
minimum_periods: int,
|
||
|
) -> np.ndarray:
|
||
|
alpha = 1.0 / (1.0 + com)
|
||
|
old_wt_factor = 1.0 - alpha
|
||
|
new_wt = 1.0 if adjust else alpha
|
||
|
old_wt = np.ones(values.shape[1])
|
||
|
|
||
|
result = np.empty(values.shape)
|
||
|
weighted = values[0].copy()
|
||
|
nobs = (~np.isnan(weighted)).astype(np.int64)
|
||
|
result[0] = np.where(nobs >= minimum_periods, weighted, np.nan)
|
||
|
for i in range(1, len(values)):
|
||
|
cur = values[i]
|
||
|
is_observations = ~np.isnan(cur)
|
||
|
nobs += is_observations.astype(np.int64)
|
||
|
for j in numba.prange(len(cur)):
|
||
|
if not np.isnan(weighted[j]):
|
||
|
if is_observations[j] or not ignore_na:
|
||
|
if normalize:
|
||
|
# note that len(deltas) = len(vals) - 1 and deltas[i]
|
||
|
# is to be used in conjunction with vals[i+1]
|
||
|
old_wt[j] *= old_wt_factor ** deltas[i - 1]
|
||
|
else:
|
||
|
weighted[j] = old_wt_factor * weighted[j]
|
||
|
if is_observations[j]:
|
||
|
if normalize:
|
||
|
# avoid numerical errors on constant series
|
||
|
if weighted[j] != cur[j]:
|
||
|
weighted[j] = (
|
||
|
old_wt[j] * weighted[j] + new_wt * cur[j]
|
||
|
)
|
||
|
if normalize:
|
||
|
weighted[j] = weighted[j] / (old_wt[j] + new_wt)
|
||
|
if adjust:
|
||
|
old_wt[j] += new_wt
|
||
|
else:
|
||
|
old_wt[j] = 1.0
|
||
|
else:
|
||
|
weighted[j] += cur[j]
|
||
|
elif is_observations[j]:
|
||
|
weighted[j] = cur[j]
|
||
|
|
||
|
result[i] = np.where(nobs >= minimum_periods, weighted, np.nan)
|
||
|
|
||
|
return result
|
||
|
|
||
|
return ewm_table
|