# Origin: 3RNN/Lib/site-packages/pandas/core/window/numba_.py
# (repository web-view snapshot: 352 lines, 10 KiB, Python;
#  dated 2024-05-26 19:49:15 +02:00)
from __future__ import annotations
import functools
from typing import (
TYPE_CHECKING,
Any,
Callable,
)
import numpy as np
from pandas.compat._optional import import_optional_dependency
from pandas.core.util.numba_ import jit_user_function
if TYPE_CHECKING:
from pandas._typing import Scalar
@functools.cache
def generate_numba_apply_func(
    func: Callable[..., Scalar],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Generate a numba jitted rolling apply function.

    1. jit the user's function with ``jit_user_function``
    2. Return a rolling apply function that calls the jitted function on
       every window

    ``nopython``, ``nogil`` and ``parallel`` are forwarded to the ``numba.jit``
    call that compiles the rolling apply function; ``func`` itself is handed
    to ``jit_user_function`` without them.

    The generated function is memoized with ``functools.cache``, so every
    argument must be hashable.

    Parameters
    ----------
    func : function
        function to be applied to each window and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        ``roll_apply(values, begin, end, minimum_periods, *args)``, which
        returns one value per ``(begin[i], end[i])`` window.
    """
    numba_func = jit_user_function(func)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_apply(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ) -> np.ndarray:
        result = np.empty(len(begin))
        for i in numba.prange(len(result)):
            start = begin[i]
            stop = end[i]
            window = values[start:stop]
            count_nan = np.sum(np.isnan(window))
            # Only evaluate the user function when the window holds at least
            # ``minimum_periods`` non-NaN values; otherwise emit NaN.
            if len(window) - count_nan >= minimum_periods:
                result[i] = numba_func(window, *args)
            else:
                result[i] = np.nan
        return result

    return roll_apply
@functools.cache
def generate_numba_ewm_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Generate a numba jitted ewm mean or sum function specified by values
    from engine_kwargs.

    The generated function is memoized with ``functools.cache``, so every
    argument (including ``deltas``) must be hashable.

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        Sets the smoothing factor: ``alpha = 1 / (1 + com)``.
    adjust : bool
        If True the weight of a new observation is 1.0 and the accumulated
        weight keeps growing; if False the new weight is ``alpha`` and the
        accumulated weight is reset to 1.0 after every observation.
    ignore_na : bool
        If True, a NaN entry neither decays the running value nor its weight.
    deltas : tuple
        Exponents applied to the decay factor; ``deltas[k]`` is used when
        moving from ``values[k]`` to ``values[k + 1]``. Only read when
        ``normalize`` is True.
    normalize : bool
        If True the running value is divided by the total weight at each
        observation; if False the running value is decayed and the new
        observation is added to it.

    Returns
    -------
    Numba function
        ``ewm(values, begin, end, minimum_periods)``, which returns an array
        with the same length as ``values``.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        # Only the slices ``start:stop`` given by begin/end are written below.
        result = np.empty(len(values))
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha

        for i in numba.prange(len(begin)):
            start = begin[i]
            stop = end[i]
            window = values[start:stop]
            sub_result = np.empty(len(window))

            # NOTE(review): reads window[0] unconditionally, so every window
            # is assumed to be non-empty -- confirm against the callers.
            weighted = window[0]
            nobs = int(not np.isnan(weighted))
            sub_result[0] = weighted if nobs >= minimum_periods else np.nan
            old_wt = 1.0

            for j in range(1, len(window)):
                cur = window[j]
                is_observation = not np.isnan(cur)
                nobs += is_observation
                if not np.isnan(weighted):
                    if is_observation or not ignore_na:
                        if normalize:
                            # note that len(deltas) = len(vals) - 1 and deltas[i]
                            # is to be used in conjunction with vals[i+1]
                            old_wt *= old_wt_factor ** deltas[start + j - 1]
                        else:
                            weighted = old_wt_factor * weighted
                        if is_observation:
                            if normalize:
                                # avoid numerical errors on constant series
                                if weighted != cur:
                                    weighted = old_wt * weighted + new_wt * cur
                                    weighted = weighted / (old_wt + new_wt)
                                if adjust:
                                    old_wt += new_wt
                                else:
                                    old_wt = 1.0
                            else:
                                weighted += cur
                elif is_observation:
                    # No valid running value yet: start from this observation.
                    weighted = cur

                sub_result[j] = weighted if nobs >= minimum_periods else np.nan

            result[start:stop] = sub_result

        return result

    return ewm
@functools.cache
def generate_numba_table_func(
    func: Callable[..., np.ndarray],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Generate a numba jitted function to apply window calculations table-wise.

    Func will be passed a M window size x N number of columns array, and
    must return a 1 x N number of columns array. Func is intended to operate
    row-wise, but the result will be transposed for axis=1.

    1. jit the user's function with ``jit_user_function``
    2. Return a rolling apply function that calls the jitted function on
       every window

    The generated function is memoized with ``functools.cache``, so every
    argument must be hashable.

    Parameters
    ----------
    func : function
        function to be applied to each window and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        ``roll_table(values, begin, end, minimum_periods, *args)``, which
        returns an array of shape ``(len(begin), values.shape[1])``.
    """
    numba_func = jit_user_function(func)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ):
        result = np.empty((len(begin), values.shape[1]))
        min_periods_mask = np.empty(result.shape)
        for i in numba.prange(len(result)):
            start = begin[i]
            stop = end[i]
            window = values[start:stop]
            # Per-column NaN count inside the window.
            count_nan = np.sum(np.isnan(window), axis=0)
            # The user function runs on every window; entries that fail the
            # min-periods check are masked out after the loop.
            sub_result = numba_func(window, *args)
            nan_mask = len(window) - count_nan >= minimum_periods
            min_periods_mask[i, :] = nan_mask
            result[i, :] = sub_result
        result = np.where(min_periods_mask, result, np.nan)
        return result

    return roll_table
# This function will no longer be needed once numba supports
# axis for all np.nan* agg functions
# https://github.com/numba/numba/issues/1269
@functools.cache
def generate_manual_numpy_nan_agg_with_axis(nan_func):
    """
    Generate a numba jitted function that applies ``nan_func`` to each column.

    The generated function is always compiled with ``nopython=True``,
    ``nogil=True`` and ``parallel=True`` and is memoized per ``nan_func``
    with ``functools.cache``.

    Parameters
    ----------
    nan_func : function
        Aggregation called on each column ``table[:, i]``; its return value
        is stored as a single element of the result.

    Returns
    -------
    Numba function
        ``nan_agg_with_axis(table)``, which returns an array of length
        ``table.shape[1]``.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=True, nogil=True, parallel=True)
    def nan_agg_with_axis(table):
        result = np.empty(table.shape[1])
        for i in numba.prange(table.shape[1]):
            partition = table[:, i]
            result[i] = nan_func(partition)
        return result

    return nan_agg_with_axis
@functools.cache
def generate_numba_ewm_table_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Generate a numba jitted ewm mean or sum function applied table wise specified
    by values from engine_kwargs.

    The generated function is memoized with ``functools.cache``, so every
    argument (including ``deltas``) must be hashable.

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        Sets the smoothing factor: ``alpha = 1 / (1 + com)``.
    adjust : bool
        If True the weight of a new observation is 1.0 and the accumulated
        weight keeps growing; if False the new weight is ``alpha`` and the
        accumulated weight is reset to 1.0 after every observation.
    ignore_na : bool
        If True, a NaN entry neither decays the running value nor its weight.
    deltas : tuple
        Exponents applied to the decay factor; ``deltas[i - 1]`` is used when
        processing row ``i``. Only read when ``normalize`` is True.
    normalize : bool
        If True the running value is divided by the total weight at each
        observation; if False the running value is decayed and the new
        observation is added to it.

    Returns
    -------
    Numba function
        ``ewm_table(values, begin, end, minimum_periods)``, which returns an
        array with the same shape as ``values``. ``begin`` and ``end`` are
        accepted but not read by the generated function.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha
        # One accumulated weight per column.
        old_wt = np.ones(values.shape[1])

        result = np.empty(values.shape)
        # NOTE(review): reads values[0] unconditionally, so ``values`` is
        # assumed to have at least one row -- confirm against the callers.
        weighted = values[0].copy()
        nobs = (~np.isnan(weighted)).astype(np.int64)
        result[0] = np.where(nobs >= minimum_periods, weighted, np.nan)
        # Rows are processed sequentially; columns are independent of each
        # other and are iterated with prange.
        for i in range(1, len(values)):
            cur = values[i]
            is_observations = ~np.isnan(cur)
            nobs += is_observations.astype(np.int64)
            for j in numba.prange(len(cur)):
                if not np.isnan(weighted[j]):
                    if is_observations[j] or not ignore_na:
                        if normalize:
                            # note that len(deltas) = len(vals) - 1 and deltas[i]
                            # is to be used in conjunction with vals[i+1]
                            old_wt[j] *= old_wt_factor ** deltas[i - 1]
                        else:
                            weighted[j] = old_wt_factor * weighted[j]
                        if is_observations[j]:
                            if normalize:
                                # avoid numerical errors on constant series
                                if weighted[j] != cur[j]:
                                    weighted[j] = (
                                        old_wt[j] * weighted[j] + new_wt * cur[j]
                                    )
                                    weighted[j] = weighted[j] / (old_wt[j] + new_wt)
                                if adjust:
                                    old_wt[j] += new_wt
                                else:
                                    old_wt[j] = 1.0
                            else:
                                weighted[j] += cur[j]
                elif is_observations[j]:
                    # No valid running value yet: start from this observation.
                    weighted[j] = cur[j]

            result[i] = np.where(nobs >= minimum_periods, weighted, np.nan)

        return result

    return ewm_table