from __future__ import annotations

import numpy as np
import pytest

import pandas as pd
import pandas._testing as tm
from pandas.core import ops
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseOpsUtil(BaseExtensionTests):
    """Helpers that check an op on a boxed ExtensionArray against an expectation."""

    def get_op_from_name(self, op_name: str):
        """Return the callable for ``op_name`` as resolved by ``tm.get_op_from_name``."""
        return tm.get_op_from_name(op_name)

    def check_opname(self, ser: pd.Series, op_name: str, other, exc=Exception):
        """Resolve ``op_name`` to a callable and hand it to ``_check_op``.

        ``ser`` may also be a single-column DataFrame (see
        ``test_arith_frame_with_scalar``).
        """
        operation = self.get_op_from_name(op_name)
        self._check_op(ser, operation, other, op_name, exc)

    def _combine(self, obj, other, op):
        """Compute the expected result of ``op(obj, other)`` via ``Series.combine``.

        A DataFrame is only handled when it has exactly one column; that column
        is combined and turned back into a frame with ``to_frame``.

        Raises
        ------
        NotImplementedError
            If ``obj`` is a DataFrame with a column count other than one.
        """
        if not isinstance(obj, pd.DataFrame):
            return obj.combine(other, op)
        if len(obj.columns) != 1:
            raise NotImplementedError
        return obj.iloc[:, 0].combine(other, op).to_frame()

    def _check_op(
        self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError
    ):
        """Check ``op(ser, other)``.

        With a non-None ``exc`` the op must raise that exception type.
        With ``exc=None`` the result must keep the container type of ``ser``
        and equal the pointwise expectation built by ``_combine``.
        """
        if exc is not None:
            with pytest.raises(exc):
                op(ser, other)
            return

        outcome = op(ser, other)
        reference = self._combine(ser, other, op)
        assert isinstance(outcome, type(ser))
        self.assert_equal(outcome, reference)

    def _check_divmod_op(self, ser: pd.Series, op, other, exc=Exception):
        """Check a divmod-style op, which returns a (quotient, remainder) pair.

        ``ser`` is not always a Series: ``test_divmod`` passes the scalar ``1``.

        With ``exc=None`` the pair returned by ``op(ser, other)`` is compared
        with ``//`` and ``%`` results. For ``op is divmod`` the operands keep
        their order; for any other op they are swapped.
        With a non-None ``exc``, the builtin ``divmod(ser, other)`` must raise it.
        """
        if exc is not None:
            # NOTE(review): this branch always calls the builtin ``divmod`` and
            # ignores ``op``, unlike the success branch below; confirm this
            # asymmetry is intended before changing it.
            with pytest.raises(exc):
                divmod(ser, other)
            return

        quotient, remainder = op(ser, other)
        numer, denom = (ser, other) if op is divmod else (other, ser)
        # Build both expectations before asserting, so a failure in ``%``
        # surfaces ahead of any mismatch in the quotient.
        expected_quotient, expected_remainder = numer // denom, numer % denom
        self.assert_series_equal(quotient, expected_quotient)
        self.assert_series_equal(remainder, expected_remainder)


class BaseArithmeticOpsTests(BaseOpsUtil):
    """
    Arithmetic-op tests for Series and DataFrame objects wrapping an ExtensionArray.

    Each ``*_exc`` class attribute is the exception type the matching test
    expects the op to raise. A subclass whose array supports that kind of op
    sets the attribute to ``None``. The test then compares the op's result
    with the pointwise expectation instead. Defaults:

    * series_scalar_exc = TypeError
    * frame_scalar_exc = TypeError
    * series_array_exc = TypeError
    * divmod_exc = TypeError
    """

    series_scalar_exc: type[Exception] | None = TypeError
    frame_scalar_exc: type[Exception] | None = TypeError
    series_array_exc: type[Exception] | None = TypeError
    divmod_exc: type[Exception] | None = TypeError

    def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
        # Series <op> scalar, where the scalar is the Series' first element.
        ser = pd.Series(data)
        first = ser.iloc[0]
        self.check_opname(
            ser, all_arithmetic_operators, first, exc=self.series_scalar_exc
        )

    def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
        # Single-column DataFrame <op> scalar.
        frame = pd.DataFrame({"A": data})
        self.check_opname(
            frame, all_arithmetic_operators, data[0], exc=self.frame_scalar_exc
        )

    def test_arith_series_with_array(self, data, all_arithmetic_operators):
        # Series <op> another Series filled with the first element.
        ser = pd.Series(data)
        repeated = pd.Series([ser.iloc[0]] * len(ser))
        self.check_opname(
            ser, all_arithmetic_operators, repeated, exc=self.series_array_exc
        )

    def test_divmod(self, data):
        ser = pd.Series(data)
        self._check_divmod_op(ser, divmod, 1, exc=self.divmod_exc)
        self._check_divmod_op(1, ops.rdivmod, ser, exc=self.divmod_exc)

    def test_divmod_series_array(self, data, data_for_twos):
        # NOTE(review): no ``exc`` is passed here, so ``_check_divmod_op`` uses
        # its own default rather than ``self.divmod_exc``; confirm intended.
        ser = pd.Series(data)
        self._check_divmod_op(ser, divmod, data)
        twos = data_for_twos
        self._check_divmod_op(twos, ops.rdivmod, ser)
        self._check_divmod_op(pd.Series(twos), ops.rdivmod, ser)

    def test_add_series_with_extension_array(self, data):
        # Series + raw array must match boxing the array-level sum.
        self.assert_series_equal(pd.Series(data) + data, pd.Series(data + data))

    @pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
    def test_direct_arith_with_ndframe_returns_not_implemented(
        self, request, data, box
    ):
        # Calling the array's ``__add__`` directly with a Series/DataFrame must
        # return NotImplemented; pandas unboxes the container and calls the
        # array's op itself.
        other = pd.Series(data)
        other = other.to_frame() if box is pd.DataFrame else other

        if not hasattr(data, "__add__"):
            request.node.add_marker(
                pytest.mark.xfail(
                    reason=f"{type(data).__name__} does not implement add"
                )
            )

        assert data.__add__(other) is NotImplemented


class BaseComparisonOpsTests(BaseOpsUtil):
    """Various Series and DataFrame comparison ops methods."""

    def _compare_other(self, ser: pd.Series, data, op, other):
        """Check that the comparison op agrees with the pointwise ``combine`` result.

        ``eq``/``ne`` must succeed. For any other op the op may raise, but then
        the pointwise ``combine`` must raise the same exception type.
        """
        if op.__name__ in ("eq", "ne"):
            self.assert_series_equal(op(ser, other), ser.combine(other, op))
            return

        try:
            outcome = op(ser, other)
        except Exception as err:
            # The vectorized op failed, so the pointwise fallback must fail too.
            with pytest.raises(type(err)):
                ser.combine(other, op)
        else:
            self.assert_series_equal(outcome, ser.combine(other, op))

    def test_compare_scalar(self, data, comparison_op):
        self._compare_other(pd.Series(data), data, comparison_op, 0)

    def test_compare_array(self, data, comparison_op):
        repeated = pd.Series([data[0]] * len(data))
        self._compare_other(pd.Series(data), data, comparison_op, repeated)

    @pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
    def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
        # The array's ``__eq__``/``__ne__`` must return NotImplemented for a
        # Series/DataFrame operand; pandas unboxes and dispatches itself.
        other = pd.Series(data)
        if box is pd.DataFrame:
            other = other.to_frame()

        # NOTE(review): every Python object inherits ``__eq__``/``__ne__`` from
        # ``object``, so the skip branch looks unreachable in practice.
        for dunder in ("__eq__", "__ne__"):
            if not hasattr(data, dunder):
                pytest.skip(f"{type(data).__name__} does not implement {dunder}")
            assert getattr(data, dunder)(other) is NotImplemented


class BaseUnaryOpsTests(BaseOpsUtil):
    """Unary-op tests for an ExtensionArray and its Series wrapper."""

    def test_invert(self, data):
        # ``~`` on the Series must match boxing ``~`` applied to the array.
        self.assert_series_equal(
            ~pd.Series(data, name="name"), pd.Series(~data, name="name")
        )

    @pytest.mark.parametrize("ufunc", [np.positive, np.negative, np.abs])
    def test_unary_ufunc_dunder_equivalence(self, data, ufunc):
        # Each dunder (__pos__, __neg__, __abs__) succeeds exactly when the
        # matching numpy ufunc does, and both produce the same array.
        dunder_name = {
            np.positive: "__pos__",
            np.negative: "__neg__",
            np.abs: "__abs__",
        }[ufunc]

        try:
            via_dunder = getattr(data, dunder_name)()
        except Exception as err:
            # The dunder failed, so the ufunc must fail as well.
            with pytest.raises((type(err), TypeError)):
                ufunc(data)
        else:
            self.assert_extension_array_equal(via_dunder, ufunc(data))