# Inzynierka/Lib/site-packages/pandas/core/computation/expressions.py
#
# 284 lines
# 7.3 KiB
# Python
# Raw Normal View History
#
# 2023-06-02 12:51:02 +02:00
"""
Expressions
-----------
Offer fast expression evaluation through numexpr
"""
from __future__ import annotations
import operator
import warnings
import numpy as np
from pandas._config import get_option
from pandas._typing import FuncType
from pandas.util._exceptions import find_stack_level
from pandas.core.computation.check import NUMEXPR_INSTALLED
from pandas.core.ops import roperator
if NUMEXPR_INSTALLED:
    import numexpr as ne

# When truthy, _evaluate_standard / _evaluate_numexpr report whether numexpr
# was used via _store_test_result (toggled by set_test_mode).
_TEST_MODE: bool | None = None
# One ``True`` per evaluation actually computed by numexpr while test mode
# is on; drained by get_test_result.
_TEST_RESULT: list[bool] = []
# Whether numexpr dispatch is enabled. set_use_numexpr only changes it when
# numexpr is installed, so it can never become True without numexpr.
USE_NUMEXPR = NUMEXPR_INSTALLED
# Dispatch targets bound by set_use_numexpr; None until it is first called.
_evaluate: FuncType | None = None
_where: FuncType | None = None

# the set of dtypes that we will allow to pass to numexpr, keyed by the
# ``dtype_check`` name given to _can_use_numexpr ("evaluate" or "where")
_ALLOWED_DTYPES = {
    "evaluate": {"int64", "int32", "float64", "float32", "bool"},
    "where": {"int64", "float64", "bool"},
}

# numexpr is only used when the left operand has strictly MORE elements than
# this (``a.size > _MIN_ELEMENTS``); below that its overhead outweighs the gain
_MIN_ELEMENTS = 1_000_000
def set_use_numexpr(v: bool = True) -> None:
    """
    Enable or disable numexpr-accelerated evaluation.

    If numexpr is not installed the request is ignored and ``USE_NUMEXPR``
    keeps its current value. Either way, the module-level dispatch targets
    ``_evaluate`` and ``_where`` are rebound to the numexpr-backed or the
    standard implementations, according to ``USE_NUMEXPR``.

    Parameters
    ----------
    v : bool, default True
        Whether numexpr should be used.
    """
    global USE_NUMEXPR, _evaluate, _where

    if NUMEXPR_INSTALLED:
        USE_NUMEXPR = v

    if USE_NUMEXPR:
        _evaluate, _where = _evaluate_numexpr, _where_numexpr
    else:
        _evaluate, _where = _evaluate_standard, _where_standard
def set_numexpr_threads(n=None) -> None:
    """
    Set the number of threads numexpr may use.

    Does nothing unless numexpr is installed and currently enabled.

    Parameters
    ----------
    n : int, optional
        Thread count. If omitted (None), numexpr's detected core count is
        used, which effectively resets the setting.
    """
    if not (NUMEXPR_INSTALLED and USE_NUMEXPR):
        return

    thread_count = ne.detect_number_of_cores() if n is None else n
    ne.set_num_threads(thread_count)
def _evaluate_standard(op, op_str, a, b):
    """
    Standard evaluation: compute ``op(a, b)`` directly, without numexpr.

    ``op_str`` is unused. It is accepted so this function shares its
    signature with ``_evaluate_numexpr``, because either one may be bound
    as ``_evaluate``.
    """
    if _TEST_MODE:
        # record (as a no-op) that numexpr was not used for this call
        _store_test_result(False)
    outcome = op(a, b)
    return outcome
def _can_use_numexpr(op, op_str, a, b, dtype_check) -> bool:
    """
    Return True if numexpr WILL be used to evaluate ``a`` and ``b``.

    Parameters
    ----------
    op : callable
        Unused; kept for signature symmetry with the evaluators.
    op_str : str or None
        numexpr spelling of the operation. None disables numexpr.
    a, b : operands
        ``a`` must expose ``.size``. Operands with a ``dtype`` attribute
        (ndarray, Series) contribute their dtype name to the check.
    dtype_check : str
        Key into ``_ALLOWED_DTYPES`` ("evaluate" or "where").
    """
    if op_str is None:
        return False

    # too few elements: numexpr would only add overhead
    if a.size <= _MIN_ELEMENTS:
        return False

    # dtype names of the operands that have one; scalars contribute nothing
    seen = {operand.dtype.name for operand in (a, b) if hasattr(operand, "dtype")}

    # no dtypes to check, or every dtype seen is whitelisted
    return not seen or seen <= _ALLOWED_DTYPES[dtype_check]
def _evaluate_numexpr(op, op_str, a, b):
    """
    Evaluate ``op(a, b)`` with numexpr when eligible, else in Python space.

    numexpr is attempted only if ``_can_use_numexpr`` approves the operands.
    The call falls back to ``_evaluate_standard`` in two cases: numexpr
    raises ``TypeError``, or it raises ``NotImplementedError`` and
    ``_bool_arith_fallback`` approves the fallback. Any other
    ``NotImplementedError`` is re-raised.

    Parameters
    ----------
    op : callable
        The binary operator; names starting with "r" (after stripping
        underscores) are treated as reversed operators.
    op_str : str or None
        numexpr spelling of ``op``; None means numexpr is not attempted.
    a, b : operands, in the order the caller passed them.
    """
    result = None

    if _can_use_numexpr(op, op_str, a, b, "evaluate"):
        # we were originally called by a reversed op method: give numexpr
        # the operands swapped. The original ``a``/``b`` stay untouched so
        # the Python fallback below still sees them in caller order.
        if op.__name__.strip("_").startswith("r"):
            lhs, rhs = b, a
        else:
            lhs, rhs = a, b

        try:
            result = ne.evaluate(
                f"a_value {op_str} b_value",
                local_dict={"a_value": lhs, "b_value": rhs},
                casting="safe",
            )
        except TypeError:
            # numexpr raises eg for array ** array with integers
            # (https://github.com/pydata/numexpr/issues/379)
            pass
        except NotImplementedError:
            # only swallow it when _bool_arith_fallback agrees (it warns)
            if not _bool_arith_fallback(op_str, lhs, rhs):
                raise

    if _TEST_MODE:
        _store_test_result(result is not None)

    if result is None:
        result = _evaluate_standard(op, op_str, a, b)

    return result
# Maps each binary operator (and its reversed ``roperator`` counterpart) to
# the token spliced into the numexpr expression built by _evaluate_numexpr
# (f"a_value {op_str} b_value"). A value of ``None`` means numexpr is never
# used for that operator: evaluate() goes straight to _evaluate_standard.
# evaluate() indexes this dict directly, so operators missing from it raise
# KeyError.
_op_str_mapping = {
    operator.add: "+",
    roperator.radd: "+",
    operator.mul: "*",
    roperator.rmul: "*",
    operator.sub: "-",
    roperator.rsub: "-",
    operator.truediv: "/",
    roperator.rtruediv: "/",
    # floordiv not supported by numexpr 2.x
    operator.floordiv: None,
    roperator.rfloordiv: None,
    # we require Python semantics for mod of negative for backwards compatibility
    # see https://github.com/pydata/numexpr/issues/365
    # so sticking with unaccelerated for now GH#36552
    operator.mod: None,
    roperator.rmod: None,
    operator.pow: "**",
    roperator.rpow: "**",
    operator.eq: "==",
    operator.ne: "!=",
    operator.le: "<=",
    operator.lt: "<",
    operator.ge: ">=",
    operator.gt: ">",
    operator.and_: "&",
    roperator.rand_: "&",
    operator.or_: "|",
    roperator.ror_: "|",
    operator.xor: "^",
    roperator.rxor: "^",
    divmod: None,
    roperator.rdivmod: None,
}
def _where_standard(cond, a, b):
    """
    Take elements from ``a`` where ``cond`` is true and from ``b``
    elsewhere, using plain ``np.where`` (no numexpr).
    """
    # Caller is responsible for extracting ndarray if necessary
    selected = np.where(cond, a, b)
    return selected
def _where_numexpr(cond, a, b):
    """
    ``where`` via numexpr when the operands qualify, else ``np.where``.

    Eligibility is decided by ``_can_use_numexpr`` using the "where" dtype
    whitelist. Only ``a`` and ``b`` are checked; ``cond`` is not.
    """
    # Caller is responsible for extracting ndarray if necessary
    if _can_use_numexpr(None, "where", a, b, "where"):
        computed = ne.evaluate(
            "where(cond_value, a_value, b_value)",
            local_dict={"cond_value": cond, "a_value": a, "b_value": b},
            casting="safe",
        )
        if computed is not None:
            return computed
    return _where_standard(cond, a, b)
# turn myself on: at import time, bind the _evaluate / _where dispatch
# targets according to the ``compute.use_numexpr`` option
set_use_numexpr(get_option("compute.use_numexpr"))
def _has_bool_dtype(x):
    """
    Return whether ``x`` is boolean-typed.

    Objects with a ``dtype`` are compared against ``bool``. Objects without
    one (e.g. plain Python scalars) are checked with ``isinstance`` against
    ``bool`` / ``np.bool_``.
    """
    try:
        answer = x.dtype == bool
    except AttributeError:
        answer = isinstance(x, (bool, np.bool_))
    return answer
# Arithmetic operators numexpr does not support for the bool dtype, each
# mapped to the logical operator the fallback warning recommends instead.
_BOOL_OP_UNSUPPORTED = {
    "+": "|",
    "*": "&",
    "-": "^",
}
def _bool_arith_fallback(op_str, a, b) -> bool:
    """
    Check if we should fallback to the python `_evaluate_standard` in case
    of an unsupported operation by numexpr, which is the case for some
    boolean ops.

    The answer is True only when both operands are boolean-typed and
    ``op_str`` is a key of ``_BOOL_OP_UNSUPPORTED``. In that case a warning
    naming the recommended logical operator is emitted first.

    Returns
    -------
    bool
        True to evaluate in Python space, False otherwise.
    """
    operands_are_bool = _has_bool_dtype(a) and _has_bool_dtype(b)
    if not operands_are_bool or op_str not in _BOOL_OP_UNSUPPORTED:
        return False

    warnings.warn(
        f"evaluating in Python space because the {repr(op_str)} "
        "operator is not supported by numexpr for the bool dtype, "
        f"use {repr(_BOOL_OP_UNSUPPORTED[op_str])} instead.",
        stacklevel=find_stack_level(),
    )
    return True
def evaluate(op, a, b, use_numexpr: bool = True):
    """
    Evaluate and return the expression of the op on a and b.

    Parameters
    ----------
    op : callable
        The binary operator to apply (not an operand); must be a key of
        ``_op_str_mapping``.
    a : left operand
    b : right operand
    use_numexpr : bool, default True
        Whether to try to use numexpr. Operators that map to ``None`` in
        ``_op_str_mapping`` are always evaluated without numexpr.

    Raises
    ------
    KeyError
        If ``op`` is not a key of ``_op_str_mapping``.
    """
    op_str = _op_str_mapping[op]
    if op_str is None or not use_numexpr:
        return _evaluate_standard(op, op_str, a, b)
    # error: "None" not callable
    return _evaluate(op, op_str, a, b)  # type: ignore[misc]
def where(cond, a, b, use_numexpr: bool = True):
    """
    Evaluate the where condition ``cond`` on ``a`` and ``b``.

    Parameters
    ----------
    cond : np.ndarray[bool]
    a : value(s) returned where ``cond`` is True
    b : value(s) returned where ``cond`` is False
    use_numexpr : bool, default True
        Whether to try to use numexpr (through the ``_where`` dispatch
        target); when False, ``_where_standard`` is called directly.
    """
    assert _where is not None
    if not use_numexpr:
        return _where_standard(cond, a, b)
    return _where(cond, a, b)
def set_test_mode(v: bool = True) -> None:
    """
    Turn numexpr-usage tracking on or off and clear the record.

    While tracking is on, every evaluation actually computed by numexpr
    appends a ``True`` to the record. The record is read (and reset) by
    ``get_test_result``.
    """
    global _TEST_MODE, _TEST_RESULT
    _TEST_RESULT = []
    _TEST_MODE = v
def _store_test_result(used_numexpr: bool) -> None:
    """Record a numexpr evaluation; falsy values are not recorded."""
    if not used_numexpr:
        return
    _TEST_RESULT.append(used_numexpr)
def get_test_result() -> list[bool]:
    """
    Return the recorded numexpr usage and start a fresh, empty record.
    """
    global _TEST_RESULT
    res, _TEST_RESULT = _TEST_RESULT, []
    return res