Inzynierka/Lib/site-packages/pandas/core/window/online.py
2023-06-02 12:51:02 +02:00

119 lines
3.6 KiB
Python

from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from pandas.compat._optional import import_optional_dependency
def generate_online_numba_ewma_func(
nopython: bool,
nogil: bool,
parallel: bool,
):
"""
Generate a numba jitted groupby ewma function specified by values
from engine_kwargs.
Parameters
----------
nopython : bool
nopython to be passed into numba.jit
nogil : bool
nogil to be passed into numba.jit
parallel : bool
parallel to be passed into numba.jit
Returns
-------
Numba function
"""
if TYPE_CHECKING:
import numba
else:
numba = import_optional_dependency("numba")
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
def online_ewma(
values: np.ndarray,
deltas: np.ndarray,
minimum_periods: int,
old_wt_factor: float,
new_wt: float,
old_wt: np.ndarray,
adjust: bool,
ignore_na: bool,
):
"""
Compute online exponentially weighted mean per column over 2D values.
Takes the first observation as is, then computes the subsequent
exponentially weighted mean accounting minimum periods.
"""
result = np.empty(values.shape)
weighted_avg = values[0]
nobs = (~np.isnan(weighted_avg)).astype(np.int64)
result[0] = np.where(nobs >= minimum_periods, weighted_avg, np.nan)
for i in range(1, len(values)):
cur = values[i]
is_observations = ~np.isnan(cur)
nobs += is_observations.astype(np.int64)
for j in numba.prange(len(cur)):
if not np.isnan(weighted_avg[j]):
if is_observations[j] or not ignore_na:
# note that len(deltas) = len(vals) - 1 and deltas[i] is to be
# used in conjunction with vals[i+1]
old_wt[j] *= old_wt_factor ** deltas[j - 1]
if is_observations[j]:
# avoid numerical errors on constant series
if weighted_avg[j] != cur[j]:
weighted_avg[j] = (
(old_wt[j] * weighted_avg[j]) + (new_wt * cur[j])
) / (old_wt[j] + new_wt)
if adjust:
old_wt[j] += new_wt
else:
old_wt[j] = 1.0
elif is_observations[j]:
weighted_avg[j] = cur[j]
result[i] = np.where(nobs >= minimum_periods, weighted_avg, np.nan)
return result, old_wt
return online_ewma
class EWMMeanState:
    """
    Mutable state carried between successive online EWM mean computations.

    Parameters
    ----------
    com : float
        Center of mass; the smoothing factor is ``alpha = 1 / (1 + com)``.
    adjust : bool
        Stored and forwarded to the ewm function. Also selects the weight of
        a new observation: ``1.0`` when True, ``alpha`` otherwise.
    ignore_na : bool
        Stored and forwarded unchanged to the ewm function.
    axis : int
        Together with ``shape``, picks the length of the weight vector:
        ``shape[axis - 1]``.
    shape : tuple of int
        Shape of the object being aggregated.
    """

    def __init__(
        self,
        com: float,
        adjust: bool,
        ignore_na: bool,
        axis: int,
        shape: tuple[int, ...],
    ) -> None:
        alpha = 1.0 / (1.0 + com)
        self.axis = axis
        self.shape = shape
        self.adjust = adjust
        self.ignore_na = ignore_na
        self.new_wt = 1.0 if adjust else alpha
        self.old_wt_factor = 1.0 - alpha
        # One running weight per entry along ``shape[axis - 1]``.
        self.old_wt = np.ones(self.shape[self.axis - 1])
        # Last row produced by the most recent run_ewm call; None until then.
        self.last_ewm = None

    def run_ewm(self, weighted_avg, deltas, min_periods, ewm_func):
        """
        Run ``ewm_func`` with the stored state and remember what it returns.

        Parameters
        ----------
        weighted_avg : np.ndarray
            Values passed as the first argument of ``ewm_func``.
        deltas : np.ndarray
            Passed through to ``ewm_func``.
        min_periods : int
            Passed through to ``ewm_func``.
        ewm_func : callable
            Called positionally as ``ewm_func(weighted_avg, deltas,
            min_periods, old_wt_factor, new_wt, old_wt, adjust, ignore_na)``
            and expected to return ``(result, old_wt)``.

        Returns
        -------
        The ``result`` returned by ``ewm_func``.
        """
        result, old_wt = ewm_func(
            weighted_avg,
            deltas,
            min_periods,
            self.old_wt_factor,
            self.new_wt,
            self.old_wt,
            self.adjust,
            self.ignore_na,
        )
        # Persist the state so a later call can continue from here.
        self.old_wt = old_wt
        self.last_ewm = result[-1]
        return result

    def reset(self) -> None:
        """Restore ``old_wt`` and ``last_ewm`` to their initial values."""
        self.old_wt = np.ones(self.shape[self.axis - 1])
        self.last_ewm = None