206 lines
6.8 KiB
Python
206 lines
6.8 KiB
Python
|
import operator
|
||
|
|
||
|
import numpy as np
|
||
|
import pytest
|
||
|
|
||
|
from pandas._libs.tslibs import IncompatibleFrequency
|
||
|
|
||
|
import pandas as pd
|
||
|
from pandas import Series
|
||
|
import pandas._testing as tm
|
||
|
|
||
|
|
||
|
def _permute(obj):
|
||
|
return obj.take(np.random.permutation(len(obj)))
|
||
|
|
||
|
|
||
|
class TestSeriesFlexArithmetic:
|
||
|
@pytest.mark.parametrize(
|
||
|
"ts",
|
||
|
[
|
||
|
(lambda x: x, lambda x: x * 2, False),
|
||
|
(lambda x: x, lambda x: x[::2], False),
|
||
|
(lambda x: x, lambda x: 5, True),
|
||
|
(lambda x: tm.makeFloatSeries(), lambda x: tm.makeFloatSeries(), True),
|
||
|
],
|
||
|
)
|
||
|
@pytest.mark.parametrize(
|
||
|
"opname", ["add", "sub", "mul", "floordiv", "truediv", "pow"]
|
||
|
)
|
||
|
def test_flex_method_equivalence(self, opname, ts):
|
||
|
# check that Series.{opname} behaves like Series.__{opname}__,
|
||
|
tser = tm.makeTimeSeries().rename("ts")
|
||
|
|
||
|
series = ts[0](tser)
|
||
|
other = ts[1](tser)
|
||
|
check_reverse = ts[2]
|
||
|
|
||
|
op = getattr(Series, opname)
|
||
|
alt = getattr(operator, opname)
|
||
|
|
||
|
result = op(series, other)
|
||
|
expected = alt(series, other)
|
||
|
tm.assert_almost_equal(result, expected)
|
||
|
if check_reverse:
|
||
|
rop = getattr(Series, "r" + opname)
|
||
|
result = rop(series, other)
|
||
|
expected = alt(other, series)
|
||
|
tm.assert_almost_equal(result, expected)
|
||
|
|
||
|
def test_flex_method_subclass_metadata_preservation(self, all_arithmetic_operators):
|
||
|
# GH 13208
|
||
|
class MySeries(Series):
|
||
|
_metadata = ["x"]
|
||
|
|
||
|
@property
|
||
|
def _constructor(self):
|
||
|
return MySeries
|
||
|
|
||
|
opname = all_arithmetic_operators
|
||
|
op = getattr(Series, opname)
|
||
|
m = MySeries([1, 2, 3], name="test")
|
||
|
m.x = 42
|
||
|
result = op(m, 1)
|
||
|
assert result.x == 42
|
||
|
|
||
|
|
||
|
class TestSeriesArithmetic:
|
||
|
# Some of these may end up in tests/arithmetic, but are not yet sorted
|
||
|
|
||
|
def test_add_series_with_period_index(self):
|
||
|
rng = pd.period_range("1/1/2000", "1/1/2010", freq="A")
|
||
|
ts = Series(np.random.randn(len(rng)), index=rng)
|
||
|
|
||
|
result = ts + ts[::2]
|
||
|
expected = ts + ts
|
||
|
expected[1::2] = np.nan
|
||
|
tm.assert_series_equal(result, expected)
|
||
|
|
||
|
result = ts + _permute(ts[::2])
|
||
|
tm.assert_series_equal(result, expected)
|
||
|
|
||
|
msg = "Input has different freq=D from PeriodIndex\\(freq=A-DEC\\)"
|
||
|
with pytest.raises(IncompatibleFrequency, match=msg):
|
||
|
ts + ts.asfreq("D", how="end")
|
||
|
|
||
|
@pytest.mark.parametrize(
|
||
|
"target_add,input_value,expected_value",
|
||
|
[
|
||
|
("!", ["hello", "world"], ["hello!", "world!"]),
|
||
|
("m", ["hello", "world"], ["hellom", "worldm"]),
|
||
|
],
|
||
|
)
|
||
|
def test_string_addition(self, target_add, input_value, expected_value):
|
||
|
# GH28658 - ensure adding 'm' does not raise an error
|
||
|
a = Series(input_value)
|
||
|
|
||
|
result = a + target_add
|
||
|
expected = Series(expected_value)
|
||
|
tm.assert_series_equal(result, expected)
|
||
|
|
||
|
|
||
|
# ------------------------------------------------------------------
|
||
|
# Comparisons
|
||
|
|
||
|
|
||
|
class TestSeriesFlexComparison:
|
||
|
def test_comparison_flex_basic(self):
|
||
|
left = pd.Series(np.random.randn(10))
|
||
|
right = pd.Series(np.random.randn(10))
|
||
|
|
||
|
tm.assert_series_equal(left.eq(right), left == right)
|
||
|
tm.assert_series_equal(left.ne(right), left != right)
|
||
|
tm.assert_series_equal(left.le(right), left < right)
|
||
|
tm.assert_series_equal(left.lt(right), left <= right)
|
||
|
tm.assert_series_equal(left.gt(right), left > right)
|
||
|
tm.assert_series_equal(left.ge(right), left >= right)
|
||
|
|
||
|
# axis
|
||
|
for axis in [0, None, "index"]:
|
||
|
tm.assert_series_equal(left.eq(right, axis=axis), left == right)
|
||
|
tm.assert_series_equal(left.ne(right, axis=axis), left != right)
|
||
|
tm.assert_series_equal(left.le(right, axis=axis), left < right)
|
||
|
tm.assert_series_equal(left.lt(right, axis=axis), left <= right)
|
||
|
tm.assert_series_equal(left.gt(right, axis=axis), left > right)
|
||
|
tm.assert_series_equal(left.ge(right, axis=axis), left >= right)
|
||
|
|
||
|
#
|
||
|
msg = "No axis named 1 for object type"
|
||
|
for op in ["eq", "ne", "le", "le", "gt", "ge"]:
|
||
|
with pytest.raises(ValueError, match=msg):
|
||
|
getattr(left, op)(right, axis=1)
|
||
|
|
||
|
|
||
|
class TestSeriesComparison:
|
||
|
def test_comparison_different_length(self):
|
||
|
a = Series(["a", "b", "c"])
|
||
|
b = Series(["b", "a"])
|
||
|
with pytest.raises(ValueError):
|
||
|
a < b
|
||
|
|
||
|
a = Series([1, 2])
|
||
|
b = Series([2, 3, 4])
|
||
|
with pytest.raises(ValueError):
|
||
|
a == b
|
||
|
|
||
|
@pytest.mark.parametrize("opname", ["eq", "ne", "gt", "lt", "ge", "le"])
|
||
|
def test_ser_flex_cmp_return_dtypes(self, opname):
|
||
|
# GH#15115
|
||
|
ser = Series([1, 3, 2], index=range(3))
|
||
|
const = 2
|
||
|
result = getattr(ser, opname)(const).dtypes
|
||
|
expected = np.dtype("bool")
|
||
|
assert result == expected
|
||
|
|
||
|
@pytest.mark.parametrize("opname", ["eq", "ne", "gt", "lt", "ge", "le"])
|
||
|
def test_ser_flex_cmp_return_dtypes_empty(self, opname):
|
||
|
# GH#15115 empty Series case
|
||
|
ser = Series([1, 3, 2], index=range(3))
|
||
|
empty = ser.iloc[:0]
|
||
|
const = 2
|
||
|
result = getattr(empty, opname)(const).dtypes
|
||
|
expected = np.dtype("bool")
|
||
|
assert result == expected
|
||
|
|
||
|
@pytest.mark.parametrize(
|
||
|
"op",
|
||
|
[operator.eq, operator.ne, operator.le, operator.lt, operator.ge, operator.gt],
|
||
|
)
|
||
|
@pytest.mark.parametrize(
|
||
|
"names", [(None, None, None), ("foo", "bar", None), ("baz", "baz", "baz")]
|
||
|
)
|
||
|
def test_ser_cmp_result_names(self, names, op):
|
||
|
# datetime64 dtype
|
||
|
dti = pd.date_range("1949-06-07 03:00:00", freq="H", periods=5, name=names[0])
|
||
|
ser = Series(dti).rename(names[1])
|
||
|
result = op(ser, dti)
|
||
|
assert result.name == names[2]
|
||
|
|
||
|
# datetime64tz dtype
|
||
|
dti = dti.tz_localize("US/Central")
|
||
|
ser = Series(dti).rename(names[1])
|
||
|
result = op(ser, dti)
|
||
|
assert result.name == names[2]
|
||
|
|
||
|
# timedelta64 dtype
|
||
|
tdi = dti - dti.shift(1)
|
||
|
ser = Series(tdi).rename(names[1])
|
||
|
result = op(ser, tdi)
|
||
|
assert result.name == names[2]
|
||
|
|
||
|
# interval dtype
|
||
|
if op in [operator.eq, operator.ne]:
|
||
|
# interval dtype comparisons not yet implemented
|
||
|
ii = pd.interval_range(start=0, periods=5, name=names[0])
|
||
|
ser = Series(ii).rename(names[1])
|
||
|
result = op(ser, ii)
|
||
|
assert result.name == names[2]
|
||
|
|
||
|
# categorical
|
||
|
if op in [operator.eq, operator.ne]:
|
||
|
# categorical dtype comparisons raise for inequalities
|
||
|
cidx = tdi.astype("category")
|
||
|
ser = Series(cidx).rename(names[1])
|
||
|
result = op(ser, cidx)
|
||
|
assert result.name == names[2]
|