2720 lines
92 KiB
Python
2720 lines
92 KiB
Python
"""
|
||
This file contains a minimal set of tests for compliance with the extension
array interface test suite, and should contain no other tests.
The test suite for the full functionality of the array is located in
`pandas/tests/arrays/`.
The tests in this file are inherited from the BaseExtensionTests, and only
minimal tweaks should be applied to get the tests passing (by overwriting a
parent method).
Additional tests should either be added to one of the BaseExtensionTests
classes (if they are relevant for the extension interface for all dtypes), or
be added to the array-specific tests in `pandas/tests/arrays/`.
|
||
"""
|
||
from datetime import (
|
||
date,
|
||
datetime,
|
||
time,
|
||
timedelta,
|
||
)
|
||
from decimal import Decimal
|
||
from io import (
|
||
BytesIO,
|
||
StringIO,
|
||
)
|
||
import operator
|
||
import pickle
|
||
import re
|
||
|
||
import numpy as np
|
||
import pytest
|
||
|
||
from pandas._libs import lib
|
||
from pandas.compat import (
|
||
PY311,
|
||
is_ci_environment,
|
||
is_platform_windows,
|
||
pa_version_under7p0,
|
||
pa_version_under8p0,
|
||
pa_version_under9p0,
|
||
pa_version_under11p0,
|
||
)
|
||
from pandas.errors import PerformanceWarning
|
||
|
||
from pandas.core.dtypes.common import is_any_int_dtype
|
||
from pandas.core.dtypes.dtypes import CategoricalDtypeType
|
||
|
||
import pandas as pd
|
||
import pandas._testing as tm
|
||
from pandas.api.types import (
|
||
is_bool_dtype,
|
||
is_float_dtype,
|
||
is_integer_dtype,
|
||
is_numeric_dtype,
|
||
is_signed_integer_dtype,
|
||
is_string_dtype,
|
||
is_unsigned_integer_dtype,
|
||
)
|
||
from pandas.tests.extension import base
|
||
|
||
pa = pytest.importorskip("pyarrow", minversion="7.0.0")
|
||
|
||
from pandas.core.arrays.arrow.array import ArrowExtensionArray
|
||
|
||
from pandas.core.arrays.arrow.dtype import ArrowDtype # isort:skip
|
||
|
||
|
||
@pytest.fixture(params=tm.ALL_PYARROW_DTYPES, ids=str)
def dtype(request):
    """Return an ``ArrowDtype`` wrapping each entry of ``tm.ALL_PYARROW_DTYPES``."""
    pyarrow_type = request.param
    return ArrowDtype(pyarrow_dtype=pyarrow_type)
|
||
|
||
|
||
@pytest.fixture
def data(dtype):
    """
    Length-100 array for the parametrized ``dtype``.

    Every supported type uses the same layout,
    ``head * 4 + [None] + middle * 44 + [None] + tail``, where ``head``,
    ``middle`` and ``tail`` are each a pair of scalars. The two missing
    entries therefore sit at positions 8 and 97.

    Raises
    ------
    NotImplementedError
        If no sample scalars are defined for the pyarrow type.
    """
    pa_dtype = dtype.pyarrow_dtype
    types = pa.types

    if types.is_boolean(pa_dtype):
        head, middle, tail = [True, False], [True, False], [True, False]
    elif types.is_floating(pa_dtype):
        head, middle, tail = [1.0, 0.0], [-2.0, -1.0], [0.5, 99.5]
    elif types.is_signed_integer(pa_dtype):
        head, middle, tail = [1, 0], [-2, -1], [1, 99]
    elif types.is_unsigned_integer(pa_dtype):
        head, middle, tail = [1, 0], [2, 1], [1, 99]
    elif types.is_decimal(pa_dtype):
        head = [Decimal("1"), Decimal("0.0")]
        middle = [Decimal("-2.0"), Decimal("-1.0")]
        tail = [Decimal("0.5"), Decimal("33.123")]
    elif types.is_date(pa_dtype):
        head = [date(2022, 1, 1), date(1999, 12, 31)]
        middle = [date(2022, 1, 1), date(2022, 1, 1)]
        tail = [date(1999, 12, 31), date(1999, 12, 31)]
    elif types.is_timestamp(pa_dtype):
        head = [datetime(2020, 1, 1, 1, 1, 1, 1), datetime(1999, 1, 1, 1, 1, 1, 1)]
        middle = [datetime(2020, 1, 1, 1), datetime(1999, 1, 1, 1)]
        tail = [datetime(2020, 1, 1), datetime(1999, 1, 1)]
    elif types.is_duration(pa_dtype):
        head = [timedelta(1), timedelta(1, 1)]
        middle = [timedelta(-1), timedelta(0)]
        tail = [timedelta(-10), timedelta(10)]
    elif types.is_time(pa_dtype):
        head = [time(12, 0), time(0, 12)]
        middle = [time(0, 0), time(1, 1)]
        tail = [time(0, 5), time(5, 0)]
    elif types.is_string(pa_dtype):
        head, middle, tail = ["a", "b"], ["1", "2"], ["!", ">"]
    elif types.is_binary(pa_dtype):
        head, middle, tail = [b"a", b"b"], [b"1", b"2"], [b"!", b">"]
    else:
        raise NotImplementedError

    values = head * 4 + [None] + middle * 44 + [None] + tail
    return pd.array(values, dtype=dtype)
|
||
|
||
|
||
@pytest.fixture
def data_missing(data):
    """Length-2 array ``[NA, Valid]`` built from the first element of ``data``."""
    array_cls = type(data)
    first_valid = data[0]
    return array_cls._from_sequence([None, first_valid], dtype=data.dtype)
|
||
|
||
|
||
@pytest.fixture(params=["data", "data_missing"])
def all_data(request, data, data_missing):
    """Parametrized fixture returning the 'data' or 'data_missing' array.

    Used to test dtype conversion with and without missing values.
    """
    by_name = {"data": data, "data_missing": data_missing}
    # .get keeps the original fall-through result (None) for any other param.
    return by_name.get(request.param)
|
||
|
||
|
||
@pytest.fixture
def data_for_grouping(dtype):
    """
    Data for factorization, grouping, and unique tests.

    Laid out as [B, B, NA, NA, A, A, B, C], where A < B < C and NA is
    missing. For boolean types only two distinct values exist, so C
    repeats B there.

    Raises
    ------
    NotImplementedError
        If no sample scalars are defined for the pyarrow type.
    """
    pa_dtype = dtype.pyarrow_dtype
    types = pa.types

    if types.is_boolean(pa_dtype):
        low, mid, high = False, True, True
    elif types.is_floating(pa_dtype):
        low, mid, high = -1.1, 0.0, 1.1
    elif types.is_signed_integer(pa_dtype):
        low, mid, high = -1, 0, 1
    elif types.is_unsigned_integer(pa_dtype):
        low, mid, high = 0, 1, 10
    elif types.is_date(pa_dtype):
        low, mid, high = date(1999, 12, 31), date(2010, 1, 1), date(2022, 1, 1)
    elif types.is_timestamp(pa_dtype):
        low = datetime(1999, 1, 1, 1, 1, 1, 1)
        mid = datetime(2020, 1, 1)
        high = datetime(2020, 1, 1, 1)
    elif types.is_duration(pa_dtype):
        low, mid, high = timedelta(-1), timedelta(0), timedelta(1, 4)
    elif types.is_time(pa_dtype):
        low, mid, high = time(0, 0), time(0, 12), time(12, 12)
    elif types.is_string(pa_dtype):
        low, mid, high = "a", "b", "c"
    elif types.is_binary(pa_dtype):
        low, mid, high = b"a", b"b", b"c"
    elif types.is_decimal(pa_dtype):
        low, mid, high = Decimal("-1.1"), Decimal("0.0"), Decimal("1.1")
    else:
        raise NotImplementedError

    layout = [mid, mid, None, None, low, low, mid, high]
    return pd.array(layout, dtype=dtype)
|
||
|
||
|
||
@pytest.fixture
def data_for_sorting(data_for_grouping):
    """
    Length-3 array with a known sort order.

    Picks positions 0, 7 and 4 of ``data_for_grouping``, i.e. [B, C, A]
    with A < B < C.
    """
    picked = [data_for_grouping[pos] for pos in (0, 7, 4)]
    array_cls = type(data_for_grouping)
    return array_cls._from_sequence(picked, dtype=data_for_grouping.dtype)
|
||
|
||
|
||
@pytest.fixture
def data_missing_for_sorting(data_for_grouping):
    """
    Length-3 array with a known sort order and one missing value.

    Picks positions 0, 2 and 4 of ``data_for_grouping``, i.e. [B, NA, A]
    with A < B and NA missing.
    """
    picked = [data_for_grouping[pos] for pos in (0, 2, 4)]
    array_cls = type(data_for_grouping)
    return array_cls._from_sequence(picked, dtype=data_for_grouping.dtype)
|
||
|
||
|
||
@pytest.fixture
def data_for_twos(data):
    """
    Length-100 array of twos for integer and floating pyarrow types.

    For every other type the ``data`` fixture is returned unchanged; tests
    will be xfailed where 2 is not a valid scalar for the pyarrow type.
    """
    pa_dtype = data.dtype.pyarrow_dtype
    holds_numbers = pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype)
    if not holds_numbers:
        return data
    return pd.array([2] * 100, dtype=data.dtype)
|
||
|
||
|
||
@pytest.fixture
def na_value():
    """The scalar missing value used by these tests: ``pd.NA``."""
    missing = pd.NA
    return missing
|
||
|
||
|
||
class TestBaseCasting(base.BaseCastingTests):
    """Casting tests from the base suite, with a pyarrow-specific xfail."""

    def test_astype_str(self, data, request):
        # Binary types are expected to fail the inherited check.
        pa_dtype = data.dtype.pyarrow_dtype
        if pa.types.is_binary(pa_dtype):
            mark = pytest.mark.xfail(
                reason=f"For {pa_dtype} .astype(str) decodes.",
            )
            request.node.add_marker(mark)
        super().test_astype_str(data)
|
||
|
||
|
||
class TestConstructors(base.BaseConstructorsTests):
    """Constructor tests from the base suite plus pyarrow-specific cases."""

    def test_from_dtype(self, data, request):
        # String and decimal types are expected to fail the inherited check,
        # each for its own reason.
        pa_dtype = data.dtype.pyarrow_dtype
        reason = None
        if pa.types.is_string(pa_dtype):
            reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
        elif pa.types.is_decimal(pa_dtype):
            reason = f"pyarrow.type_for_alias cannot infer {pa_dtype}"

        if reason is not None:
            mark = pytest.mark.xfail(
                reason=reason,
            )
            request.node.add_marker(mark)
        super().test_from_dtype(data)

    def test_from_sequence_pa_array(self, data):
        # https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784
        # data._data = pa.ChunkedArray
        array_cls = type(data)

        from_chunked = array_cls._from_sequence(data._data)
        tm.assert_extension_array_equal(from_chunked, data)
        assert isinstance(from_chunked._data, pa.ChunkedArray)

        # A combined (non-chunked) input must still end up chunked internally.
        from_combined = array_cls._from_sequence(data._data.combine_chunks())
        tm.assert_extension_array_equal(from_combined, data)
        assert isinstance(from_combined._data, pa.ChunkedArray)

    def test_from_sequence_pa_array_notimplemented(self, request):
        interval_type = pa.month_day_nano_interval()
        with pytest.raises(NotImplementedError, match="Converting strings to"):
            ArrowExtensionArray._from_sequence_of_strings(["12-1"], dtype=interval_type)

    def test_from_sequence_of_strings_pa_array(self, data, request):
        pa_dtype = data.dtype.pyarrow_dtype

        # At most one xfail applies; the branch order decides which one.
        mark = None
        if pa.types.is_time64(pa_dtype) and pa_dtype.equals("time64[ns]") and not PY311:
            mark = pytest.mark.xfail(
                reason="Nanosecond time parsing not supported.",
            )
        elif pa_version_under11p0 and (
            pa.types.is_duration(pa_dtype) or pa.types.is_decimal(pa_dtype)
        ):
            mark = pytest.mark.xfail(
                raises=pa.ArrowNotImplementedError,
                reason=f"pyarrow doesn't support parsing {pa_dtype}",
            )
        elif pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None:
            if is_platform_windows() and is_ci_environment():
                mark = pytest.mark.xfail(
                    raises=pa.ArrowInvalid,
                    reason=(
                        "TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
                        "on CI to path to the tzdata for pyarrow."
                    ),
                )
        if mark is not None:
            request.node.add_marker(mark)

        array_cls = type(data)

        # Round-trip through a chunked string array ...
        as_strings = data._data.cast(pa.string())
        roundtrip = array_cls._from_sequence_of_strings(as_strings, dtype=data.dtype)
        tm.assert_extension_array_equal(roundtrip, data)

        # ... and through the same strings combined into a single chunk.
        as_strings = as_strings.combine_chunks()
        roundtrip = array_cls._from_sequence_of_strings(as_strings, dtype=data.dtype)
        tm.assert_extension_array_equal(roundtrip, data)
|
||
|
||
|
||
class TestGetitemTests(base.BaseGetitemTests):
    """Getitem tests inherited unchanged from the base suite."""
|
||
|
||
|
||
class TestBaseAccumulateTests(base.BaseAccumulateTests):
    """Accumulation (cumsum, cumprod, ...) tests for pyarrow-backed data."""

    def check_accumulate(self, ser, op_name, skipna):
        """Compare ``ser.<op_name>`` against the same op on a Float64 cast."""
        actual = getattr(ser, op_name)(skipna=skipna)

        if ser.dtype.kind == "m":
            # Just check that we match the integer behavior.
            ser = ser.astype("int64[pyarrow]")
            actual = actual.astype("int64[pyarrow]")

        actual = actual.astype("Float64")
        reference = getattr(ser.astype("Float64"), op_name)(skipna=skipna)
        self.assert_series_equal(actual, reference, check_dtype=False)

    @pytest.mark.parametrize("skipna", [True, False])
    def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
        pa_type = data.dtype.pyarrow_dtype
        op_name = all_numeric_accumulations

        cumsum_capable_type = (
            pa.types.is_integer(pa_type)
            or pa.types.is_floating(pa_type)
            or pa.types.is_duration(pa_type)
        )
        if cumsum_capable_type and op_name == "cumsum" and not pa_version_under9p0:
            pytest.skip("These work, are tested by test_accumulate_series.")

        ser = pd.Series(data)
        accumulate = getattr(ser, op_name)

        with pytest.raises(NotImplementedError):
            accumulate(skipna=skipna)

    @pytest.mark.parametrize("skipna", [True, False])
    def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request):
        pa_type = data.dtype.pyarrow_dtype
        op_name = all_numeric_accumulations
        ser = pd.Series(data)

        # Decide whether this dtype/op pair is covered by the "raises" test.
        if pa.types.is_string(pa_type) or pa.types.is_binary(pa_type):
            covered_by_raises = op_name in ["cumsum", "cumprod"]
        elif pa.types.is_temporal(pa_type) and not pa.types.is_duration(pa_type):
            covered_by_raises = op_name in ["cumsum", "cumprod"]
        elif pa.types.is_duration(pa_type):
            covered_by_raises = op_name == "cumprod"
        else:
            covered_by_raises = False

        if covered_by_raises:
            pytest.skip(
                "These should *not* work, we test in test_accumulate_series_raises "
                "that these correctly raise."
            )

        if op_name != "cumsum" or pa_version_under9p0:
            if request.config.option.skip_slow:
                # equivalent to marking these cases with @pytest.mark.slow,
                # these xfails take a long time to run because pytest
                # renders the exception messages even when not showing them
                pytest.skip("pyarrow xfail slow")

            mark = pytest.mark.xfail(
                reason=f"{op_name} not implemented",
                raises=NotImplementedError,
            )
            request.node.add_marker(mark)
        elif pa.types.is_boolean(pa_type) or pa.types.is_decimal(pa_type):
            # Reaching this branch already implies op_name == "cumsum".
            mark = pytest.mark.xfail(
                reason=f"{op_name} not implemented for {pa_type}",
                raises=NotImplementedError,
            )
            request.node.add_marker(mark)

        self.check_accumulate(ser, op_name, skipna)
|
||
|
||
|
||
class TestBaseNumericReduce(base.BaseNumericReduceTests):
    """Numeric reduction tests for pyarrow-backed data."""

    def check_reduce(self, ser, op_name, skipna):
        """Compare a reduction on ``ser`` with the same reduction on a reference."""
        pa_dtype = ser.dtype.pyarrow_dtype
        # "count" is called without arguments; every other op receives skipna.
        call_kwargs = {} if op_name == "count" else {"skipna": skipna}

        result = getattr(ser, op_name)(**call_kwargs)

        if pa.types.is_boolean(pa_dtype):
            # Can't convert if ser contains NA
            pytest.skip(
                "pandas boolean data with NA does not fully support all reductions"
            )
        elif pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
            ser = ser.astype("Float64")

        expected = getattr(ser, op_name)(**call_kwargs)
        tm.assert_almost_equal(result, expected)

    @pytest.mark.parametrize("skipna", [True, False])
    def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
        pa_dtype = data.dtype.pyarrow_dtype
        opname = all_numeric_reductions

        ser = pd.Series(data)

        temporal_unsupported = ["sum", "var", "skew", "kurt", "prod"]
        text_unsupported = [
            "sum",
            "mean",
            "median",
            "prod",
            "std",
            "sem",
            "var",
            "skew",
            "kurt",
        ]

        if pa.types.is_temporal(pa_dtype) and opname in temporal_unsupported:
            # summing timedeltas is one case that *is* well-defined
            should_work = pa.types.is_duration(pa_dtype) and opname == "sum"
        elif (
            pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
        ) and opname in text_unsupported:
            should_work = False
        else:
            should_work = True

        if not should_work:
            # matching the non-pyarrow versions, these operations *should* not
            # work for these dtypes
            msg = f"does not support reduction '{opname}'"
            with pytest.raises(TypeError, match=msg):
                getattr(ser, opname)(skipna=skipna)

            return

        # Every remaining known failure is marked with the same xfail.
        needs_xfail = (
            opname in {"skew", "kurt"}
            or (
                opname in {"var", "std", "median"}
                and pa_version_under7p0
                and pa.types.is_decimal(pa_dtype)
            )
            or (opname == "sem" and pa_version_under8p0)
            or (
                pa.types.is_boolean(pa_dtype)
                and opname in {"sem", "std", "var", "median"}
            )
        )
        if needs_xfail:
            xfail_mark = pytest.mark.xfail(
                raises=TypeError,
                reason=(
                    f"{opname} is not implemented in "
                    f"pyarrow={pa.__version__} for {pa_dtype}"
                ),
            )
            request.node.add_marker(xfail_mark)
        super().test_reduce_series(data, all_numeric_reductions, skipna)

    @pytest.mark.parametrize("typ", ["int64", "uint64", "float64"])
    def test_median_not_approximate(self, typ):
        # GH 52679
        ser = pd.Series([1, 2], dtype=f"{typ}[pyarrow]")
        assert ser.median() == 1.5
|
||
|
||
|
||
class TestBaseBooleanReduce(base.BaseBooleanReduceTests):
    """Boolean reduction (any/all) tests for pyarrow-backed data."""

    @pytest.mark.parametrize("skipna", [True, False])
    def test_reduce_series(
        self, data, all_boolean_reductions, skipna, na_value, request
    ):
        pa_dtype = data.dtype.pyarrow_dtype
        op_name = all_boolean_reductions

        if pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype):
            # We *might* want to make this behave like the non-pyarrow cases,
            # but have not yet decided.
            xfail_mark = pytest.mark.xfail(
                raises=TypeError,
                reason=(
                    f"{op_name} is not implemented in "
                    f"pyarrow={pa.__version__} for {pa_dtype}"
                ),
            )
            request.node.add_marker(xfail_mark)

        ser = pd.Series(data)
        reduce = getattr(ser, op_name)

        non_duration_temporal = pa.types.is_temporal(
            pa_dtype
        ) and not pa.types.is_duration(pa_dtype)
        if non_duration_temporal:
            # xref GH#34479 we support this in our non-pyarrow datetime64 dtypes,
            # but it isn't obvious we _should_. For now, we keep the pyarrow
            # behavior which does not support this.
            with pytest.raises(TypeError, match="does not support reduction"):
                reduce(skipna=skipna)

            return

        outcome = reduce(skipna=skipna)
        # Identity check: the outcome must be the bool True for "any" and
        # the bool False otherwise.
        assert outcome is (op_name == "any")
|
||
|
||
|
||
class TestBaseGroupby(base.BaseGroupbyTests):
    """Groupby tests; boolean data is xfailed where three distinct values are needed."""

    def test_groupby_extension_no_sort(self, data_for_grouping, request):
        pa_dtype = data_for_grouping.dtype.pyarrow_dtype
        if pa.types.is_boolean(pa_dtype):
            mark = pytest.mark.xfail(
                reason=f"{pa_dtype} only has 2 unique possible values",
            )
            request.node.add_marker(mark)
        super().test_groupby_extension_no_sort(data_for_grouping)

    def test_groupby_extension_transform(self, data_for_grouping, request):
        pa_dtype = data_for_grouping.dtype.pyarrow_dtype
        if pa.types.is_boolean(pa_dtype):
            mark = pytest.mark.xfail(
                reason=f"{pa_dtype} only has 2 unique possible values",
            )
            request.node.add_marker(mark)
        super().test_groupby_extension_transform(data_for_grouping)

    @pytest.mark.parametrize("as_index", [True, False])
    def test_groupby_extension_agg(self, as_index, data_for_grouping, request):
        pa_dtype = data_for_grouping.dtype.pyarrow_dtype
        if pa.types.is_boolean(pa_dtype):
            # Unlike the two tests above, this xfail pins the exception type.
            mark = pytest.mark.xfail(
                raises=ValueError,
                reason=f"{pa_dtype} only has 2 unique possible values",
            )
            request.node.add_marker(mark)
        super().test_groupby_extension_agg(as_index, data_for_grouping)

    def test_in_numeric_groupby(self, data_for_grouping):
        if not is_string_dtype(data_for_grouping.dtype):
            super().test_in_numeric_groupby(data_for_grouping)
            return

        # String data: a plain sum must raise, and numeric_only=True must
        # leave only the numeric column "C".
        frame = pd.DataFrame(
            {
                "A": [1, 1, 2, 2, 3, 3, 1, 4],
                "B": data_for_grouping,
                "C": [1, 1, 1, 1, 1, 1, 1, 1],
            }
        )
        expected = pd.Index(["C"])

        with pytest.raises(TypeError, match="does not support"):
            frame.groupby("A").sum().columns

        numeric_columns = frame.groupby("A").sum(numeric_only=True).columns
        tm.assert_index_equal(numeric_columns, expected)
|
||
|
||
|
||
class TestBaseDtype(base.BaseDtypeTests):
    """Dtype tests, with pyarrow-specific xfails and string-dtype special cases."""

    def test_check_dtype(self, data, request):
        pa_dtype = data.dtype.pyarrow_dtype
        if pa.types.is_decimal(pa_dtype) and pa_version_under8p0:
            mark = pytest.mark.xfail(
                raises=ValueError,
                reason="decimal string repr affects numpy comparison",
            )
            request.node.add_marker(mark)
        super().test_check_dtype(data)

    def test_construct_from_string_own_name(self, dtype, request):
        pa_dtype = dtype.pyarrow_dtype
        if pa.types.is_decimal(pa_dtype):
            mark = pytest.mark.xfail(
                raises=NotImplementedError,
                reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
            )
            request.node.add_marker(mark)

        if not pa.types.is_string(pa_dtype):
            super().test_construct_from_string_own_name(dtype)
            return

        # We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
        msg = r"string\[pyarrow\] should be constructed by StringDtype"
        with pytest.raises(TypeError, match=msg):
            dtype.construct_from_string(dtype.name)

    def test_is_dtype_from_name(self, dtype, request):
        pa_dtype = dtype.pyarrow_dtype
        if pa.types.is_string(pa_dtype):
            # We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
            assert not type(dtype).is_dtype(dtype.name)
            return

        if pa.types.is_decimal(pa_dtype):
            mark = pytest.mark.xfail(
                raises=NotImplementedError,
                reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
            )
            request.node.add_marker(mark)
        super().test_is_dtype_from_name(dtype)

    def test_construct_from_string_another_type_raises(self, dtype):
        msg = r"'another_type' must end with '\[pyarrow\]'"
        dtype_cls = type(dtype)
        with pytest.raises(TypeError, match=msg):
            dtype_cls.construct_from_string("another_type")

    def test_get_common_dtype(self, dtype, request):
        pa_dtype = dtype.pyarrow_dtype

        # ``unit`` / ``tz`` are only read after the matching type check passed.
        non_default_timestamp = pa.types.is_timestamp(pa_dtype) and (
            pa_dtype.unit != "ns" or pa_dtype.tz is not None
        )
        non_ns_duration = pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns"

        expected_to_fail = (
            pa.types.is_date(pa_dtype)
            or pa.types.is_time(pa_dtype)
            or non_default_timestamp
            or non_ns_duration
            or pa.types.is_binary(pa_dtype)
            or pa.types.is_decimal(pa_dtype)
        )
        if expected_to_fail:
            mark = pytest.mark.xfail(
                reason=(
                    f"{pa_dtype} does not have associated numpy "
                    "dtype findable by find_common_type"
                )
            )
            request.node.add_marker(mark)
        super().test_get_common_dtype(dtype)

    def test_is_not_string_type(self, dtype):
        pa_dtype = dtype.pyarrow_dtype
        if not pa.types.is_string(pa_dtype):
            super().test_is_not_string_type(dtype)
            return
        # For pyarrow string types the dtype *is* expected to be a string dtype.
        assert is_string_dtype(dtype)
|
||
|
||
|
||
class TestBaseIndex(base.BaseIndexTests):
    """Index tests inherited unchanged from the base suite."""
|
||
|
||
|
||
class TestBaseInterface(base.BaseInterfaceTests):
    """Interface tests; the view test is marked xfail and never executed."""

    # run=False: the test body is not executed at all.
    @pytest.mark.xfail(
        run=False, reason="GH 45419: pyarrow.ChunkedArray does not support views."
    )
    def test_view(self, data):
        super().test_view(data)
|
||
|
||
|
||
class TestBaseMissing(base.BaseMissingTests):
    """Missing-data tests, adjusted for the PerformanceWarning fillna can emit."""

    def test_fillna_no_op_returns_copy(self, data):
        # Work on the non-missing entries only, so fillna has nothing to fill.
        non_null = data[~data.isna()]
        fill_value = non_null[0]

        filled = non_null.fillna(fill_value)
        assert filled is not non_null
        self.assert_extension_array_equal(filled, non_null)

        with tm.assert_produces_warning(PerformanceWarning):
            filled = non_null.fillna(method="backfill")
        assert filled is not non_null
        self.assert_extension_array_equal(filled, non_null)

    def test_fillna_series_method(self, data_missing, fillna_method):
        # A warning is only expected when a fill method is actually given.
        expect_warning = fillna_method is not None
        with tm.maybe_produces_warning(
            PerformanceWarning, expect_warning, check_stacklevel=False
        ):
            super().test_fillna_series_method(data_missing, fillna_method)
|
||
|
||
|
||
class TestBasePrinting(base.BasePrintingTests):
    """Printing tests inherited unchanged from the base suite."""
|
||
|
||
|
||
class TestBaseReshaping(base.BaseReshapingTests):
    """Reshaping tests; the transpose test is marked xfail and never executed."""

    # run=False: the test body is not executed at all.
    @pytest.mark.xfail(
        run=False, reason="GH 45419: pyarrow.ChunkedArray does not support views"
    )
    def test_transpose(self, data):
        super().test_transpose(data)
|
||
|
||
|
||
class TestBaseSetitem(base.BaseSetitemTests):
    """Setitem tests; the view-preservation test is marked xfail and never executed."""

    # run=False: the test body is not executed at all.
    @pytest.mark.xfail(
        run=False, reason="GH 45419: pyarrow.ChunkedArray does not support views"
    )
    def test_setitem_preserves_views(self, data):
        super().test_setitem_preserves_views(data)
|
||
|
||
|
||
class TestBaseParsing(base.BaseParsingTests):
    """CSV round-trip test for pyarrow-backed columns."""

    @pytest.mark.parametrize("engine", ["c", "python"])
    def test_EA_types(self, engine, data, request):
        pa_dtype = data.dtype.pyarrow_dtype
        is_binary = pa.types.is_binary(pa_dtype)

        # At most one xfail applies; the branch order decides which one.
        mark = None
        if pa.types.is_boolean(pa_dtype):
            mark = pytest.mark.xfail(raises=TypeError, reason="GH 47534")
        elif pa.types.is_decimal(pa_dtype):
            mark = pytest.mark.xfail(
                raises=NotImplementedError,
                reason=f"Parameterized types {pa_dtype} not supported.",
            )
        elif pa.types.is_timestamp(pa_dtype) and pa_dtype.unit in ("us", "ns"):
            mark = pytest.mark.xfail(
                raises=ValueError,
                reason="https://github.com/pandas-dev/pandas/issues/49767",
            )
        elif is_binary:
            mark = pytest.mark.xfail(reason="CSV parsers don't correctly handle binary")
        if mark is not None:
            request.node.add_marker(mark)

        dtype_name = str(data.dtype)
        df = pd.DataFrame({"with_dtype": pd.Series(data, dtype=dtype_name)})
        csv_text = df.to_csv(index=False, na_rep=np.nan)

        # Binary columns are read back through a bytes buffer, all others
        # through a text buffer.
        buffer_cls = BytesIO if is_binary else StringIO
        csv_buffer = buffer_cls(csv_text)

        result = pd.read_csv(
            csv_buffer, dtype={"with_dtype": dtype_name}, engine=engine
        )
        self.assert_frame_equal(result, df)
|
||
|
||
|
||
class TestBaseUnaryOps(base.BaseUnaryOpsTests):
    """Unary-operator tests; inversion is only expected to work for booleans."""

    def test_invert(self, data, request):
        """Run the inherited invert test, xfailing every non-boolean type.

        Non-boolean pyarrow types are expected to raise
        ``pa.ArrowNotImplementedError``.
        """
        pa_dtype = data.dtype.pyarrow_dtype
        if not pa.types.is_boolean(pa_dtype):
            request.node.add_marker(
                pytest.mark.xfail(
                    raises=pa.ArrowNotImplementedError,
                    # The message previously read "does support", which
                    # contradicted the xfail it explains.
                    reason=f"pyarrow.compute.invert does not support {pa_dtype}",
                )
            )
        super().test_invert(data)
|
||
|
||
|
||
class TestBaseMethods(base.BaseMethodsTests):
    """Method tests from the base suite with pyarrow-specific xfails and additions."""

    @pytest.mark.parametrize("periods", [1, -2])
    def test_diff(self, data, periods, request):
        pa_dtype = data.dtype.pyarrow_dtype
        if pa.types.is_unsigned_integer(pa_dtype) and periods == 1:
            mark = pytest.mark.xfail(
                raises=pa.ArrowInvalid,
                reason=(
                    f"diff with {pa_dtype} and periods={periods} will overflow"
                ),
            )
            request.node.add_marker(mark)
        super().test_diff(data, periods)

    def test_value_counts_returns_pyarrow_int64(self, data):
        # GH 51462
        counts = data[:10].value_counts()
        assert counts.dtype == ArrowDtype(pa.int64())

    def test_value_counts_with_normalize(self, data, request):
        uniques = data[:10].unique()
        valid = np.array(uniques[~uniques.isna()])
        n_valid = len(valid)
        ser = pd.Series(uniques, dtype=uniques.dtype)

        result = ser.value_counts(normalize=True).sort_index()

        # Each distinct non-missing value is expected to get an equal share.
        expected = pd.Series(
            [1 / n_valid] * n_valid, index=result.index, name="proportion"
        ).astype("double[pyarrow]")

        self.assert_series_equal(result, expected)

    def test_argmin_argmax(
        self, data_for_sorting, data_missing_for_sorting, na_value, request
    ):
        pa_dtype = data_for_sorting.dtype.pyarrow_dtype
        mark = None
        if pa.types.is_boolean(pa_dtype):
            mark = pytest.mark.xfail(
                reason=f"{pa_dtype} only has 2 unique possible values",
            )
        elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
            mark = pytest.mark.xfail(
                reason=f"No pyarrow kernel for {pa_dtype}",
                raises=pa.ArrowNotImplementedError,
            )
        if mark is not None:
            request.node.add_marker(mark)
        super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)

    @pytest.mark.parametrize(
        "op_name, skipna, expected",
        [
            ("idxmax", True, 0),
            ("idxmin", True, 2),
            ("argmax", True, 0),
            ("argmin", True, 2),
            ("idxmax", False, np.nan),
            ("idxmin", False, np.nan),
            ("argmax", False, -1),
            ("argmin", False, -1),
        ],
    )
    def test_argreduce_series(
        self, data_missing_for_sorting, op_name, skipna, expected, request
    ):
        pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype
        if pa.types.is_decimal(pa_dtype) and pa_version_under7p0 and skipna:
            mark = pytest.mark.xfail(
                reason=f"No pyarrow kernel for {pa_dtype}",
                raises=pa.ArrowNotImplementedError,
            )
            request.node.add_marker(mark)
        super().test_argreduce_series(
            data_missing_for_sorting, op_name, skipna, expected
        )

    def test_factorize(self, data_for_grouping, request):
        pa_dtype = data_for_grouping.dtype.pyarrow_dtype
        if pa.types.is_boolean(pa_dtype):
            mark = pytest.mark.xfail(
                reason=f"{pa_dtype} only has 2 unique possible values",
            )
            request.node.add_marker(mark)
        super().test_factorize(data_for_grouping)

    _combine_le_expected_dtype = "bool[pyarrow]"

    def test_combine_add(self, data_repeated, request):
        pa_dtype = next(data_repeated(1)).dtype.pyarrow_dtype
        is_duration = pa.types.is_duration(pa_dtype)

        if pa.types.is_temporal(pa_dtype) and not is_duration:
            # analogous to datetime64, these cannot be added
            first, second = data_repeated(2)
            left = pd.Series(first)
            right = pd.Series(second)
            with pytest.raises(TypeError):
                left.combine(right, lambda a, b: a + b)
            return

        if is_duration:
            # TODO: this fails on the scalar addition constructing 'expected'
            # but not in the actual 'combine' call, so may be salvage-able
            mark = pytest.mark.xfail(
                raises=TypeError,
                reason=f"{pa_dtype} cannot be added to {pa_dtype}",
            )
            request.node.add_marker(mark)
        super().test_combine_add(data_repeated)

    def test_searchsorted(self, data_for_sorting, as_series, request):
        pa_dtype = data_for_sorting.dtype.pyarrow_dtype
        if pa.types.is_boolean(pa_dtype):
            mark = pytest.mark.xfail(
                reason=f"{pa_dtype} only has 2 unique possible values",
            )
            request.node.add_marker(mark)
        super().test_searchsorted(data_for_sorting, as_series)

    def test_basic_equals(self, data):
        # https://github.com/pandas-dev/pandas/issues/34660
        left = pd.Series(data)
        right = pd.Series(data)
        assert left.equals(right)
|
||
|
||
|
||
class TestBaseArithmeticOps(base.BaseArithmeticOpsTests):
|
||
divmod_exc = NotImplementedError
|
||
|
||
@classmethod
def assert_equal(cls, left, right, **kwargs):
    """
    Assert ``left`` and ``right`` are equal, comparing decimal data as floats.

    When ``left`` is a DataFrame, the pyarrow type is read from the first
    column of each operand; otherwise from the operand's own dtype.
    """
    if isinstance(left, pd.DataFrame):

        def typed_values(obj):
            return obj.iloc[:, 0]

    else:

        def typed_values(obj):
            return obj

    pa_types = [typed_values(obj).dtype.pyarrow_dtype for obj in (left, right)]
    if any(pa.types.is_decimal(pa_type) for pa_type in pa_types):
        # decimal precision can resize in the result type depending on data
        # just compare the float values
        left = left.astype("float[pyarrow]")
        right = right.astype("float[pyarrow]")
    tm.assert_equal(left, right, **kwargs)
|
||
|
||
def get_op_from_name(self, op_name):
    """
    Return the callable for ``op_name``.

    Reversed true/floor division map to the numpy functions with swapped
    operands; every other name is delegated to ``tm.get_op_from_name``.
    """
    # use the numpy version that won't raise on division by zero
    swapped_numpy_ops = {
        "rtruediv": np.divide,
        "rfloordiv": np.floor_divide,
    }
    numpy_op = swapped_numpy_ops.get(op_name.strip("_"))
    if numpy_op is None:
        return tm.get_op_from_name(op_name)
    return lambda x, y: numpy_op(y, x)
|
||
|
||
def _patch_combine(self, obj, other, op):
    """
    Build the expected result of ``op`` with the original pyarrow type restored.

    The base ``_combine`` result is converted to a pyarrow array, cast
    back to the type of ``obj`` (or to a duration of the appropriate unit),
    and re-wrapped as a Series or DataFrame matching the base result.
    """
    # BaseOpsUtil._combine can upcast expected dtype
    # (because it generates expected on python scalars)
    # while ArrowExtensionArray maintains original type
    expected = base.BaseArithmeticOpsTests._combine(self, obj, other, op)

    was_frame = isinstance(expected, pd.DataFrame)
    if was_frame:
        expected_data = expected.iloc[:, 0]
        original_dtype = obj.iloc[:, 0].dtype
    else:
        expected_data = expected
        original_dtype = obj.dtype

    pa_expected = pa.array(expected_data._values)

    if pa.types.is_duration(pa_expected.type):
        # pyarrow sees sequence of datetime/timedelta objects and defaults
        # to "us" but the non-pointwise op retains unit
        unit = original_dtype.pyarrow_dtype.unit
        # Exact type match on purpose: subclasses are not treated as
        # plain datetime/timedelta here.
        if type(other) in [datetime, timedelta] and unit in ["s", "ms"]:
            # pydatetime/pytimedelta objects have microsecond reso, so we
            # take the higher reso of the original and microsecond. Note
            # this matches what we would do with DatetimeArray/TimedeltaArray
            unit = "us"
        target_type = f"duration[{unit}]"
    else:
        target_type = original_dtype.pyarrow_dtype
    pa_expected = pa_expected.cast(target_type)

    pd_expected = type(expected_data._values)(pa_expected)
    if not was_frame:
        return pd.Series(pd_expected)
    return pd.DataFrame(pd_expected, index=expected.index, columns=expected.columns)
|
||
|
||
def _is_temporal_supported(self, opname, pa_dtype):
    """Return True when ``opname`` is an add on a duration type or a
    subtract on any temporal type, and pyarrow is at least 8.0."""
    if pa_version_under8p0:
        return False
    if opname in ("__add__", "__radd__") and pa.types.is_duration(pa_dtype):
        return True
    return opname in ("__sub__", "__rsub__") and pa.types.is_temporal(pa_dtype)
|
||
|
||
def _get_scalar_exception(self, opname, pa_dtype):
    """Return the exception type expected for ``opname`` on ``pa_dtype``.

    ``None`` means the operation is expected to succeed.
    """
    temporal_ok = self._is_temporal_supported(opname, pa_dtype)

    if opname in {"__mod__", "__rmod__"}:
        return NotImplementedError
    if temporal_ok:
        return None

    is_text = pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
    if opname in ["__add__", "__radd__"] and is_text:
        return None

    is_numeric = (
        pa.types.is_floating(pa_dtype)
        or pa.types.is_integer(pa_dtype)
        or pa.types.is_decimal(pa_dtype)
    )
    return None if is_numeric else pa.ArrowNotImplementedError
|
||
|
||
def _get_arith_xfail_marker(self, opname, pa_dtype):
    """Return the xfail marker to apply for ``opname`` on ``pa_dtype``.

    Returns ``None`` when no known failure applies. The branches are checked
    in order, so at most one marker is produced.
    """
    arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype)
    is_integer = pa.types.is_integer(pa_dtype)
    is_decimal = pa.types.is_decimal(pa_dtype)
    is_numeric = pa.types.is_floating(pa_dtype) or is_integer or is_decimal

    mark = None
    if opname == "__rpow__" and is_numeric and not pa_version_under7p0:
        mark = pytest.mark.xfail(
            reason=(
                "GH#29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL "
                f"for {pa_dtype}"
            )
        )
    elif arrow_temporal_supported:
        mark = pytest.mark.xfail(
            raises=TypeError,
            reason=(
                # Trailing space keeps the two fragments from running
                # together in the reported reason.
                f"{opname} not supported between "
                f"pd.NA and {pa_dtype} Python scalar"
            ),
        )
    elif (
        opname == "__rfloordiv__"
        and (is_integer or is_decimal)
        and not pa_version_under7p0
    ):
        mark = pytest.mark.xfail(
            raises=pa.ArrowInvalid,
            reason="divide by 0",
        )
    elif opname == "__rtruediv__" and is_decimal and not pa_version_under7p0:
        mark = pytest.mark.xfail(
            raises=pa.ArrowInvalid,
            reason="divide by 0",
        )
    elif opname == "__pow__" and is_decimal and pa_version_under7p0:
        mark = pytest.mark.xfail(
            raises=pa.ArrowInvalid,
            reason="Invalid decimal function: power_checked",
        )

    return mark
|
||
|
||
def test_arith_series_with_scalar(
    self, data, all_arithmetic_operators, request, monkeypatch
):
    """Run the base Series-with-scalar arithmetic test with arrow-specific
    expected exceptions, xfail markers and ``_combine`` patching."""
    op_name = all_arithmetic_operators
    pa_dtype = data.dtype.pyarrow_dtype

    is_text = pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
    if op_name == "__rmod__" and is_text:
        pytest.skip("Skip testing Python string formatting")

    self.series_scalar_exc = self._get_scalar_exception(op_name, pa_dtype)

    xfail_mark = self._get_arith_xfail_marker(op_name, pa_dtype)
    if xfail_mark is not None:
        request.node.add_marker(xfail_mark)

    needs_patch = (
        (op_name == "__floordiv__" and pa.types.is_integer(pa_dtype))
        or pa.types.is_duration(pa_dtype)
        or pa.types.is_timestamp(pa_dtype)
    )
    if needs_patch:
        # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray
        # does not upcast
        monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
    super().test_arith_series_with_scalar(data, op_name)
|
||
|
||
def test_arith_frame_with_scalar(
    self, data, all_arithmetic_operators, request, monkeypatch
):
    """Run the base DataFrame-with-scalar arithmetic test with arrow-specific
    expected exceptions, xfail markers and ``_combine`` patching."""
    op_name = all_arithmetic_operators
    pa_dtype = data.dtype.pyarrow_dtype

    is_text = pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
    if op_name == "__rmod__" and is_text:
        pytest.skip("Skip testing Python string formatting")

    self.frame_scalar_exc = self._get_scalar_exception(op_name, pa_dtype)

    xfail_mark = self._get_arith_xfail_marker(op_name, pa_dtype)
    if xfail_mark is not None:
        request.node.add_marker(xfail_mark)

    needs_patch = (
        (op_name == "__floordiv__" and pa.types.is_integer(pa_dtype))
        or pa.types.is_duration(pa_dtype)
        or pa.types.is_timestamp(pa_dtype)
    )
    if needs_patch:
        # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray
        # does not upcast
        monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
    super().test_arith_frame_with_scalar(data, op_name)
|
||
|
||
def test_arith_series_with_array(
    self, data, all_arithmetic_operators, request, monkeypatch
):
    """Check Series-with-array arithmetic against a same-dtype Series built
    from the first element of ``data``."""
    op_name = all_arithmetic_operators
    pa_dtype = data.dtype.pyarrow_dtype

    self.series_array_exc = self._get_scalar_exception(op_name, pa_dtype)

    unsigned_subtraction = (
        op_name in ("__sub__", "__rsub__")
        and pa.types.is_unsigned_integer(pa_dtype)
        and not pa_version_under7p0
    )
    if unsigned_subtraction:
        overflow_mark = pytest.mark.xfail(
            raises=pa.ArrowInvalid,
            reason=(
                f"Implemented pyarrow.compute.subtract_checked "
                f"which raises on overflow for {pa_dtype}"
            ),
        )
        request.node.add_marker(overflow_mark)

    xfail_mark = self._get_arith_xfail_marker(op_name, pa_dtype)
    if xfail_mark is not None:
        request.node.add_marker(xfail_mark)

    ser = pd.Series(data)
    # pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray
    # since ser.iloc[0] is a python scalar
    repeated_first = pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype)
    other = pd.Series(repeated_first)

    integer_non_truediv = pa.types.is_integer(pa_dtype) and op_name not in [
        "__truediv__",
        "__rtruediv__",
    ]
    needs_patch = (
        pa.types.is_floating(pa_dtype)
        or integer_non_truediv
        or pa.types.is_duration(pa_dtype)
        or pa.types.is_timestamp(pa_dtype)
    )
    if needs_patch:
        monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
    self.check_opname(ser, op_name, other, exc=self.series_array_exc)
|
||
|
||
def test_add_series_with_extension_array(self, data, request):
    """Check ``Series + array``: non-duration temporal types must raise,
    other types defer to the base test with dtype-specific xfails."""
    pa_dtype = data.dtype.pyarrow_dtype

    if pa.types.is_temporal(pa_dtype) and not pa.types.is_duration(pa_dtype):
        # i.e. timestamp, date, time, but not timedelta; these *should*
        # raise when trying to add
        ser = pd.Series(data)
        msg = (
            "Function add_checked has no kernel matching input types"
            if pa_version_under7p0
            else "Function 'add_checked' has no kernel matching input types"
        )
        with pytest.raises(NotImplementedError, match=msg):
            # TODO: this is a pa.lib.ArrowNotImplementedError, might
            # be better to reraise a TypeError; more consistent with
            # non-pyarrow cases
            ser + data

        return

    old_pyarrow_duration = pa_version_under8p0 and pa.types.is_duration(pa_dtype)
    xfail_mark = None
    if old_pyarrow_duration or pa.types.is_boolean(pa_dtype):
        xfail_mark = pytest.mark.xfail(
            raises=NotImplementedError,
            reason=f"add_checked not implemented for {pa_dtype}",
        )
    elif pa_dtype.equals("int8"):
        xfail_mark = pytest.mark.xfail(
            raises=pa.ArrowInvalid,
            reason=f"raises on overflow for {pa_dtype}",
        )
    if xfail_mark is not None:
        request.node.add_marker(xfail_mark)
    super().test_add_series_with_extension_array(data)
|
||
|
||
|
||
class TestBaseComparisonOps(base.BaseComparisonOpsTests):
    """Comparison-operator tests for arrow-backed arrays."""

    def test_compare_array(self, data, comparison_op, na_value):
        """Compare a Series against a same-dtype Series of its first element
        and check the result against the point-wise ``Series.combine``."""
        ser = pd.Series(data)
        # pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray
        # since ser.iloc[0] is a python scalar
        repeated_first = pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype)
        other = pd.Series(repeated_first)

        if comparison_op.__name__ in ["eq", "ne"]:
            # comparison should match point-wise comparisons
            result = comparison_op(ser, other)
            # Series.combine does not calculate the NA mask correctly
            # when comparing over an array
            na_positions = (8, 97)
            for pos in na_positions:
                assert result[pos] is na_value
            expected = ser.combine(other, comparison_op)
            for pos in na_positions:
                expected[pos] = na_value
            self.assert_series_equal(result, expected)
            return

        raised = None
        try:
            result = comparison_op(ser, other)
        except Exception as err:
            raised = err

        if raised is not None:
            # The vectorized op errored, so the point-wise op must raise the
            # same exception type.
            with pytest.raises(type(raised)):
                ser.combine(other, comparison_op)
            return

        # Didn't error, then should match point-wise behavior
        expected = ser.combine(other, comparison_op)
        self.assert_series_equal(result, expected)

    def test_invalid_other_comp(self, data, comparison_op):
        """Comparing against a bare ``object()`` raises NotImplementedError."""
        # GH 48833
        expected_msg = ".* not implemented for <class 'object'>"
        with pytest.raises(NotImplementedError, match=expected_msg):
            comparison_op(data, object())

    @pytest.mark.parametrize("masked_dtype", ["boolean", "Int64", "Float64"])
    def test_comp_masked_numpy(self, masked_dtype, comparison_op):
        """Comparing an arrow Series with a masked Series of equal values
        yields an arrow boolean Series."""
        # GH 52625
        data = [1, 0, None]
        ser_masked = pd.Series(data, dtype=masked_dtype)
        ser_pa = pd.Series(data, dtype=f"{masked_dtype.lower()}[pyarrow]")
        result = comparison_op(ser_pa, ser_masked)

        # Both operands hold the same values, so strict/inequality ops are False.
        is_strict = comparison_op in [operator.lt, operator.gt, operator.ne]
        exp = [False, False, None] if is_strict else [True, True, None]
        expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
        tm.assert_series_equal(result, expected)
|
||
|
||
|
||
class TestLogicalOps:
    """Various Series and DataFrame logical ops methods."""

    @staticmethod
    def _kleene_pair():
        """Return fresh operands covering all nine True/False/NA pairings."""
        left = pd.Series(
            [True] * 3 + [False] * 3 + [None] * 3, dtype="boolean[pyarrow]"
        )
        right = pd.Series([True, False, None] * 3, dtype="boolean[pyarrow]")
        return left, right

    @staticmethod
    def _scalar_operand():
        """Return a fresh ``[True, False, NA]`` arrow boolean Series."""
        return pd.Series([True, False, None], dtype="boolean[pyarrow]")

    def _assert_pair_unchanged(self, a, b):
        """Assert that ``a`` and ``b`` still equal freshly built operands."""
        pristine_a, pristine_b = self._kleene_pair()
        tm.assert_series_equal(a, pristine_a)
        tm.assert_series_equal(b, pristine_b)

    def test_kleene_or(self):
        """``|`` between two arrow boolean Series, in both operand orders."""
        a, b = self._kleene_pair()
        expected = pd.Series(
            [True, True, True, True, False, None, True, None, None],
            dtype="boolean[pyarrow]",
        )
        tm.assert_series_equal(a | b, expected)
        tm.assert_series_equal(b | a, expected)

        # ensure we haven't mutated anything inplace
        self._assert_pair_unchanged(a, b)

    @pytest.mark.parametrize(
        "other, expected",
        [
            (None, [True, None, None]),
            (pd.NA, [True, None, None]),
            (True, [True, True, True]),
            (np.bool_(True), [True, True, True]),
            (False, [True, False, None]),
            (np.bool_(False), [True, False, None]),
        ],
    )
    def test_kleene_or_scalar(self, other, expected):
        """``|`` between an arrow boolean Series and a scalar, both orders."""
        a = self._scalar_operand()
        expected = pd.Series(expected, dtype="boolean[pyarrow]")
        tm.assert_series_equal(a | other, expected)
        tm.assert_series_equal(other | a, expected)

        # ensure we haven't mutated anything inplace
        tm.assert_series_equal(a, self._scalar_operand())

    def test_kleene_and(self):
        """``&`` between two arrow boolean Series, in both operand orders."""
        a, b = self._kleene_pair()
        expected = pd.Series(
            [True, False, None, False, False, False, None, False, None],
            dtype="boolean[pyarrow]",
        )
        tm.assert_series_equal(a & b, expected)
        tm.assert_series_equal(b & a, expected)

        # ensure we haven't mutated anything inplace
        self._assert_pair_unchanged(a, b)

    @pytest.mark.parametrize(
        "other, expected",
        [
            (None, [None, False, None]),
            (pd.NA, [None, False, None]),
            (True, [True, False, None]),
            (False, [False, False, False]),
            (np.bool_(True), [True, False, None]),
            (np.bool_(False), [False, False, False]),
        ],
    )
    def test_kleene_and_scalar(self, other, expected):
        """``&`` between an arrow boolean Series and a scalar, both orders."""
        a = self._scalar_operand()
        expected = pd.Series(expected, dtype="boolean[pyarrow]")
        tm.assert_series_equal(a & other, expected)
        tm.assert_series_equal(other & a, expected)

        # ensure we haven't mutated anything inplace
        tm.assert_series_equal(a, self._scalar_operand())

    def test_kleene_xor(self):
        """``^`` between two arrow boolean Series, in both operand orders."""
        a, b = self._kleene_pair()
        expected = pd.Series(
            [False, True, None, True, False, None, None, None, None],
            dtype="boolean[pyarrow]",
        )
        tm.assert_series_equal(a ^ b, expected)
        tm.assert_series_equal(b ^ a, expected)

        # ensure we haven't mutated anything inplace
        self._assert_pair_unchanged(a, b)

    @pytest.mark.parametrize(
        "other, expected",
        [
            (None, [None, None, None]),
            (pd.NA, [None, None, None]),
            (True, [False, True, None]),
            (np.bool_(True), [False, True, None]),
            # Plain ``False`` mirrors the or/and scalar tests, which cover
            # both the Python and the numpy boolean scalars.
            (False, [True, False, None]),
            (np.bool_(False), [True, False, None]),
        ],
    )
    def test_kleene_xor_scalar(self, other, expected):
        """``^`` between an arrow boolean Series and a scalar, both orders."""
        a = self._scalar_operand()
        expected = pd.Series(expected, dtype="boolean[pyarrow]")
        tm.assert_series_equal(a ^ other, expected)
        tm.assert_series_equal(other ^ a, expected)

        # ensure we haven't mutated anything inplace
        tm.assert_series_equal(a, self._scalar_operand())

    @pytest.mark.parametrize(
        "op, exp",
        [
            ["__and__", True],
            ["__or__", True],
            ["__xor__", False],
        ],
    )
    def test_logical_masked_numpy(self, op, exp):
        """Logical ops between an arrow and a masked boolean Series return an
        arrow boolean Series."""
        # GH 52625
        data = [True, False, None]
        ser_masked = pd.Series(data, dtype="boolean")
        ser_pa = pd.Series(data, dtype="boolean[pyarrow]")
        result = getattr(ser_pa, op)(ser_masked)
        expected = pd.Series([exp, False, None], dtype=ArrowDtype(pa.bool_()))
        tm.assert_series_equal(result, expected)
|
||
|
||
|
||
def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
    """Parameterized type strings raise, except timezone-aware timestamps."""
    unsupported_msg = "Passing pyarrow type"

    with pytest.raises(NotImplementedError, match=unsupported_msg):
        ArrowDtype.construct_from_string("not_a_real_dype[s, tz=UTC][pyarrow]")

    # but as of GH#50689, timestamptz is supported
    parsed = ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")
    assert parsed == ArrowDtype(pa.timestamp("s", "UTC"))

    with pytest.raises(NotImplementedError, match=unsupported_msg):
        ArrowDtype.construct_from_string("decimal(7, 2)[pyarrow]")
|
||
|
||
|
||
def test_arrowdtype_construct_from_string_type_only_one_pyarrow():
    """A dtype string with two ``[pyarrow]`` suffixes is rejected."""
    # GH#51225
    msg = (
        r"Passing pyarrow type specific parameters \(\[pyarrow\]\) in the "
        r"string is not supported\."
    )
    with pytest.raises(NotImplementedError, match=msg):
        pd.Series(range(3), dtype="int64[pyarrow]foobar[pyarrow]")
|
||
|
||
|
||
@pytest.mark.parametrize(
    "interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]
)
@pytest.mark.parametrize("quantile", [0.5, [0.5, 0.5]])
def test_quantile(data, interpolation, quantile, request):
    """Check ``Series.quantile`` on three copies of the first element.

    String, binary and boolean types are expected to raise; integer,
    floating, decimal (pyarrow >= 7) and temporal types are expected to
    return the repeated element; any other type is marked xfail.
    """
    pa_dtype = data.dtype.pyarrow_dtype

    # All three values are identical, so every quantile equals data[0].
    data = data.take([0, 0, 0])
    ser = pd.Series(data)

    if (
        pa.types.is_string(pa_dtype)
        or pa.types.is_binary(pa_dtype)
        or pa.types.is_boolean(pa_dtype)
    ):
        # For string, bytes, and bool, we don't *expect* to have quantile work
        # Note this matches the non-pyarrow behavior
        if pa_version_under7p0:
            msg = r"Function quantile has no kernel matching input types \(.*\)"
        else:
            msg = r"Function 'quantile' has no kernel matching input types \(.*\)"
        with pytest.raises(pa.ArrowNotImplementedError, match=msg):
            ser.quantile(q=quantile, interpolation=interpolation)
        return

    quantile_supported = (
        pa.types.is_integer(pa_dtype)
        or pa.types.is_floating(pa_dtype)
        or (pa.types.is_decimal(pa_dtype) and not pa_version_under7p0)
        or pa.types.is_temporal(pa_dtype)
    )
    if not quantile_supported:
        request.node.add_marker(
            pytest.mark.xfail(
                raises=pa.ArrowNotImplementedError,
                reason=f"quantile not supported by pyarrow for {pa_dtype}",
            )
        )
    result = ser.quantile(q=quantile, interpolation=interpolation)

    if pa.types.is_timestamp(pa_dtype) and interpolation not in ["lower", "higher"]:
        # rounding error will make the check below fail
        # (e.g. '2020-01-01 01:01:01.000001' vs '2020-01-01 01:01:01.000001024'),
        # so we'll check for now that we match the numpy analogue
        if pa_dtype.tz:
            pd_dtype = f"M8[{pa_dtype.unit}, {pa_dtype.tz}]"
        else:
            pd_dtype = f"M8[{pa_dtype.unit}]"
        ser_np = ser.astype(pd_dtype)

        expected = ser_np.quantile(q=quantile, interpolation=interpolation)
        if quantile == 0.5:
            if pa_dtype.unit == "us":
                expected = expected.to_pydatetime(warn=False)
            assert result == expected
        else:
            if pa_dtype.unit == "us":
                expected = expected.dt.floor("us")
            tm.assert_series_equal(result, expected.astype(data.dtype))
        return

    if quantile == 0.5:
        assert result == data[0]
    else:
        # Just check the values
        expected = pd.Series(data.take([0, 0]), index=[0.5, 0.5])
        if (
            pa.types.is_integer(pa_dtype)
            or pa.types.is_floating(pa_dtype)
            or pa.types.is_decimal(pa_dtype)
        ):
            expected = expected.astype("float64[pyarrow]")
            result = result.astype("float64[pyarrow]")
        tm.assert_series_equal(result, expected)
|
||
|
||
|
||
@pytest.mark.parametrize(
    "take_idx, exp_idx",
    [
        ([0, 0, 2, 2, 4, 4], [0, 4]),
        ([0, 0, 0, 2, 4, 4], [0]),
    ],
    ids=["multi_mode", "single_mode"],
)
def test_mode_dropna_true(data_for_grouping, take_idx, exp_idx):
    """``mode(dropna=True)`` returns the elements at ``exp_idx``."""
    ser = pd.Series(data_for_grouping.take(take_idx))
    expected = pd.Series(data_for_grouping.take(exp_idx))
    tm.assert_series_equal(ser.mode(dropna=True), expected)
|
||
|
||
|
||
def test_mode_dropna_false_mode_na(data):
    """``mode(dropna=False)`` reports NA when NA is (one of) the modes."""
    # GH 50982
    first = data[0]

    mostly_na = pd.Series([None, None, first], dtype=data.dtype)
    only_na = pd.Series([None], dtype=data.dtype)
    tm.assert_series_equal(mostly_na.mode(dropna=False), only_na)

    # NA and the value tie, so both are returned.
    tied = pd.Series([None, first], dtype=data.dtype)
    tm.assert_series_equal(tied.mode(dropna=False), tied)
|
||
|
||
|
||
@pytest.mark.parametrize(
    "arrow_dtype, expected_type",
    [
        (pa.binary(), bytes),
        (pa.binary(16), bytes),
        (pa.large_binary(), bytes),
        (pa.large_string(), str),
        (pa.list_(pa.int64()), list),
        (pa.large_list(pa.int64()), list),
        (pa.map_(pa.string(), pa.int64()), list),
        (pa.struct([("f1", pa.int8()), ("f2", pa.string())]), dict),
        (pa.dictionary(pa.int64(), pa.int64()), CategoricalDtypeType),
    ],
)
def test_arrow_dtype_type(arrow_dtype, expected_type):
    """``ArrowDtype.type`` maps each pyarrow type to the expected scalar type."""
    # GH 51845
    # TODO: Redundant with test_getitem_scalar once arrow_dtype exists in data fixture
    wrapped = ArrowDtype(arrow_dtype)
    assert wrapped.type == expected_type
|
||
|
||
|
||
def test_is_bool_dtype():
    """An arrow boolean array is a bool dtype and works as a boolean indexer."""
    # GH 22667
    mask = ArrowExtensionArray(pa.array([True, False, True]))
    assert is_bool_dtype(mask)
    assert pd.core.common.is_bool_indexer(mask)

    s = pd.Series(range(len(mask)))
    # Indexing with the arrow mask must match indexing with its numpy form.
    tm.assert_series_equal(s[mask], s[np.asarray(mask)])
|
||
|
||
|
||
def test_is_numeric_dtype(data):
    """``is_numeric_dtype`` is true exactly for floating, integer and decimal
    pyarrow types."""
    # GH 50563
    pa_type = data.dtype.pyarrow_dtype
    numeric_checks = (
        pa.types.is_floating,
        pa.types.is_integer,
        pa.types.is_decimal,
    )
    expect_numeric = any(check(pa_type) for check in numeric_checks)
    assert bool(is_numeric_dtype(data)) == bool(expect_numeric)
|
||
|
||
|
||
def test_is_integer_dtype(data):
    """``is_integer_dtype`` agrees with ``pa.types.is_integer``."""
    # GH 50667
    expect_integer = pa.types.is_integer(data.dtype.pyarrow_dtype)
    assert bool(is_integer_dtype(data)) == bool(expect_integer)
|
||
|
||
|
||
def test_is_any_integer_dtype(data):
    """``is_any_int_dtype`` agrees with ``pa.types.is_integer``."""
    # GH 50667
    expect_integer = pa.types.is_integer(data.dtype.pyarrow_dtype)
    assert bool(is_any_int_dtype(data)) == bool(expect_integer)
|
||
|
||
|
||
def test_is_signed_integer_dtype(data):
    """``is_signed_integer_dtype`` agrees with ``pa.types.is_signed_integer``."""
    expect_signed = pa.types.is_signed_integer(data.dtype.pyarrow_dtype)
    assert bool(is_signed_integer_dtype(data)) == bool(expect_signed)
|
||
|
||
|
||
def test_is_unsigned_integer_dtype(data):
    """``is_unsigned_integer_dtype`` agrees with
    ``pa.types.is_unsigned_integer``."""
    expect_unsigned = pa.types.is_unsigned_integer(data.dtype.pyarrow_dtype)
    assert bool(is_unsigned_integer_dtype(data)) == bool(expect_unsigned)
|
||
|
||
|
||
def test_is_float_dtype(data):
    """``is_float_dtype`` agrees with ``pa.types.is_floating``."""
    expect_float = pa.types.is_floating(data.dtype.pyarrow_dtype)
    assert bool(is_float_dtype(data)) == bool(expect_float)
|
||
|
||
|
||
def test_pickle_roundtrip(data):
    """Pickling round-trips a full and a sliced Series, and the sliced
    pickle is smaller than the full one."""
    # GH 42600
    full = pd.Series(data)
    head = full.head(2)
    full_bytes = pickle.dumps(full)
    head_bytes = pickle.dumps(head)

    assert len(full_bytes) > len(head_bytes)

    tm.assert_series_equal(pickle.loads(full_bytes), full)
    tm.assert_series_equal(pickle.loads(head_bytes), head)
|
||
|
||
|
||
def test_astype_from_non_pyarrow(data):
    """Casting a non-arrow pandas array to an ArrowDtype reproduces ``data``."""
    # GH49795
    non_arrow = data._data.to_pandas().array
    converted = non_arrow.astype(data.dtype)
    assert not isinstance(non_arrow.dtype, ArrowDtype)
    assert isinstance(converted.dtype, ArrowDtype)
    tm.assert_extension_array_equal(converted, data)
|
||
|
||
|
||
def test_astype_float_from_non_pyarrow_str():
    """A Series of numeric strings can be cast to ``float64[pyarrow]``."""
    # GH50430
    target = "float64[pyarrow]"
    converted = pd.Series(["1.0"]).astype(target)
    tm.assert_series_equal(converted, pd.Series([1.0], dtype=target))
|
||
|
||
|
||
def test_to_numpy_with_defaults(data):
    """``to_numpy()`` with default arguments matches a numpy array built from
    the data, with NA positions set to ``pd.NA`` in an object array."""
    # GH49973
    result = data.to_numpy()

    pa_type = data._data.type
    is_datetimelike = pa.types.is_duration(pa_type) or pa.types.is_timestamp(
        pa_type
    )
    # Duration/timestamp data is materialized element by element; everything
    # else is converted straight from the underlying arrow data.
    expected = np.array(list(data)) if is_datetimelike else np.array(data._data)

    if data._hasna:
        expected = expected.astype(object)
        expected[pd.isna(data)] = pd.NA

    tm.assert_numpy_array_equal(result, expected)
|
||
|
||
|
||
def test_to_numpy_int_with_na():
    """Integers stay Python ints in ``to_numpy()`` when NA is present."""
    # GH51227: ensure to_numpy does not convert int to float
    arr = pd.array([1, None], dtype="int64[pyarrow]")
    converted = arr.to_numpy()
    assert isinstance(converted[0], int)
    tm.assert_numpy_array_equal(converted, np.array([1, pd.NA], dtype=object))
|
||
|
||
|
||
@pytest.mark.parametrize("na_val, exp", [(lib.no_default, np.nan), (1, 1)])
def test_to_numpy_null_array(na_val, exp):
    """A null-typed array converts to float64, filling with ``na_value``."""
    # GH#52443
    null_arr = pd.array([pd.NA, pd.NA], dtype="null[pyarrow]")
    converted = null_arr.to_numpy(dtype="float64", na_value=na_val)
    tm.assert_numpy_array_equal(converted, np.array([exp] * 2, dtype="float64"))
|
||
|
||
|
||
def test_to_numpy_null_array_no_dtype():
    """A null-typed array converts to an object array of ``pd.NA`` when no
    dtype is given."""
    # GH#52443
    null_arr = pd.array([pd.NA, pd.NA], dtype="null[pyarrow]")
    converted = null_arr.to_numpy(dtype=None)
    tm.assert_numpy_array_equal(converted, np.array([pd.NA] * 2, dtype="object"))
|
||
|
||
|
||
def test_setitem_null_slice(data):
    """Assigning through ``arr[:]`` accepts a scalar, an array and a list."""
    # GH50248
    orig = data.copy()

    # scalar broadcast
    target = orig.copy()
    target[:] = data[0]
    broadcast = pa.array([data[0]] * len(data), type=data._data.type)
    tm.assert_extension_array_equal(target, ArrowExtensionArray(broadcast))

    # same-length extension array
    target = orig.copy()
    target[:] = data[::-1]
    tm.assert_extension_array_equal(target, data[::-1])

    # plain Python list
    target = orig.copy()
    target[:] = data.tolist()
    tm.assert_extension_array_equal(target, data)
|
||
|
||
|
||
def test_setitem_invalid_dtype(data):
    """Assigning a value of the wrong kind through ``data[:]`` raises."""
    # GH50248
    pa_type = data._data.type
    is_text = pa.types.is_string(pa_type) or pa.types.is_binary(pa_type)
    is_number_or_bool = (
        pa.types.is_integer(pa_type)
        or pa.types.is_floating(pa_type)
        or pa.types.is_boolean(pa_type)
    )

    # Text types get a number; every other type gets a string.
    fill_value = 123 if is_text else "foo"
    if is_text:
        err, msg = TypeError, "Invalid value '123' for dtype"
    elif is_number_or_bool:
        err, msg = pa.ArrowInvalid, "Could not convert"
    else:
        err, msg = TypeError, "Invalid value 'foo' for dtype"

    with pytest.raises(err, match=msg):
        data[:] = fill_value
|
||
|
||
|
||
@pytest.mark.skipif(pa_version_under8p0, reason="returns object with 7.0")
def test_from_arrow_respecting_given_dtype():
    """``to_pandas`` honours the dtype supplied through ``types_mapper``."""
    stamps = [pd.Timestamp("2019-12-31"), pd.Timestamp("2019-12-31")]
    target_dtype = ArrowDtype(pa.date64())

    date_array = pa.array(stamps, type=pa.date32())
    mapping = {pa.date32(): target_dtype}
    result = date_array.to_pandas(types_mapper=mapping.get)

    expected = pd.Series(stamps, dtype=target_dtype)
    tm.assert_series_equal(result, expected)
|
||
|
||
|
||
@pytest.mark.skipif(pa_version_under8p0, reason="doesn't raise with 7")
def test_from_arrow_respecting_given_dtype_unsafe():
    """Mapping float64 values with a fractional part to int64 raises."""
    floats = pa.array([1.5, 2.5], type=pa.float64())
    mapping = {pa.float64(): ArrowDtype(pa.int64())}
    with pytest.raises(pa.ArrowInvalid, match="Float value 1.5 was truncated"):
        floats.to_pandas(types_mapper=mapping.get)
|
||
|
||
|
||
def test_round():
    """``Series.round`` on ``float64[pyarrow]`` with positive and negative
    decimals, keeping NA in place."""
    dtype = "float64[pyarrow]"

    one_decimal = pd.Series([0.0, 1.23, 2.56, pd.NA], dtype=dtype).round(1)
    tm.assert_series_equal(
        one_decimal, pd.Series([0.0, 1.2, 2.6, pd.NA], dtype=dtype)
    )

    to_tens = pd.Series([123.4, pd.NA, 56.78], dtype=dtype).round(-1)
    tm.assert_series_equal(to_tens, pd.Series([120.0, pd.NA, 60.0], dtype=dtype))
|
||
|
||
|
||
def test_searchsorted_with_na_raises(data_for_sorting, as_series):
    """``searchsorted`` raises ValueError when the array contains NA."""
    # GH50447
    b, _, _ = data_for_sorting
    arr = data_for_sorting.take([2, 0, 1])  # to get [a, b, c]
    arr[-1] = pd.NA

    haystack = pd.Series(arr) if as_series else arr

    msg = (
        "searchsorted requires array to be sorted, "
        "which is impossible with NAs present."
    )
    with pytest.raises(ValueError, match=msg):
        haystack.searchsorted(b)
|
||
|
||
|
||
def test_sort_values_dictionary():
    """Sorting by a dictionary-typed arrow column leaves an already-sorted
    frame unchanged."""
    dict_dtype = ArrowDtype(pa.dictionary(pa.int32(), pa.string()))
    df = pd.DataFrame(
        {
            "a": pd.Series(["x", "y"], dtype=dict_dtype),
            "b": [1, 2],
        },
    )
    expected = df.copy()
    tm.assert_frame_equal(df.sort_values(by=["a", "b"]), expected)
|
||
|
||
|
||
@pytest.mark.parametrize("pat", ["abc", "a[a-z]{2}"])
def test_str_count(pat):
    """``str.count`` returns an int32 arrow Series and keeps nulls."""
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    counts = ser.str.count(pat)
    tm.assert_series_equal(
        counts, pd.Series([1, None], dtype=ArrowDtype(pa.int32()))
    )
|
||
|
||
|
||
def test_str_count_flags_unsupported():
    """``str.count`` with non-default ``flags`` raises NotImplementedError."""
    string_dtype = ArrowDtype(pa.string())
    ser = pd.Series(["abc", None], dtype=string_dtype)
    with pytest.raises(NotImplementedError, match="count not"):
        ser.str.count("abc", flags=1)
|
||
|
||
|
||
@pytest.mark.parametrize(
    "side, str_func",
    [("left", "rjust"), ("right", "ljust"), ("both", "center")],
)
def test_str_pad(side, str_func):
    """``str.pad`` matches the corresponding builtin ``str`` method."""
    string_dtype = ArrowDtype(pa.string())
    ser = pd.Series(["a", None], dtype=string_dtype)
    result = ser.str.pad(width=3, side=side, fillchar="x")

    # Build the expected padded value with the plain Python string method.
    padded = getattr("a", str_func)(3, "x")
    expected = pd.Series([padded, None], dtype=string_dtype)
    tm.assert_series_equal(result, expected)
|
||
|
||
|
||
def test_str_pad_invalid_side():
    """``str.pad`` rejects an unknown ``side`` with ValueError."""
    string_dtype = ArrowDtype(pa.string())
    ser = pd.Series(["a", None], dtype=string_dtype)
    with pytest.raises(ValueError, match="Invalid side: foo"):
        ser.str.pad(3, "foo", "x")
|
||
|
||
|
||
@pytest.mark.parametrize(
    "pat, case, na, regex, exp",
    [
        ("ab", False, None, False, [True, None]),
        ("Ab", True, None, False, [False, None]),
        ("ab", False, True, False, [True, True]),
        ("a[a-z]{1}", False, None, True, [True, None]),
        ("A[a-z]{1}", True, None, True, [False, None]),
    ],
)
def test_str_contains(pat, case, na, regex, exp):
    """``str.contains`` with literal and regex patterns, ``case`` and ``na``."""
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    found = ser.str.contains(pat, case=case, na=na, regex=regex)
    tm.assert_series_equal(found, pd.Series(exp, dtype=ArrowDtype(pa.bool_())))
|
||
|
||
|
||
def test_str_contains_flags_unsupported():
    """``str.contains`` with non-default ``flags`` raises
    NotImplementedError."""
    string_dtype = ArrowDtype(pa.string())
    ser = pd.Series(["abc", None], dtype=string_dtype)
    with pytest.raises(NotImplementedError, match="contains not"):
        ser.str.contains("a", flags=1)
|
||
|
||
|
||
@pytest.mark.parametrize(
    "side, pat, na, exp",
    [
        ("startswith", "ab", None, [True, None]),
        ("startswith", "b", False, [False, False]),
        ("endswith", "b", True, [False, True]),
        ("endswith", "bc", None, [True, None]),
    ],
)
def test_str_start_ends_with(side, pat, na, exp):
    """``str.startswith`` / ``str.endswith`` with and without an ``na`` fill."""
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    # ``side`` names the accessor method under test.
    method = getattr(ser.str, side)
    tm.assert_series_equal(
        method(pat, na=na), pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
    )
|
||
|
||
|
||
@pytest.mark.parametrize(
    "arg_name, arg",
    [
        ("pat", re.compile("b")),
        ("repl", str),
        ("case", False),
        ("flags", 1),
    ],
)
def test_str_replace_unsupported(arg_name, arg):
    """``str.replace`` raises for compiled patterns, callable replacements,
    ``case=False`` and non-default ``flags``."""
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    # Start from a supported call and override a single argument.
    kwargs = {"pat": "b", "repl": "x", "regex": True, arg_name: arg}
    with pytest.raises(NotImplementedError, match="replace is not supported"):
        ser.str.replace(**kwargs)
|
||
|
||
|
||
@pytest.mark.parametrize(
    "pat, repl, n, regex, exp",
    [
        ("a", "x", -1, False, ["xbxc", None]),
        ("a", "x", 1, False, ["xbac", None]),
        ("[a-b]", "x", -1, True, ["xxxc", None]),
    ],
)
def test_str_replace(pat, repl, n, regex, exp):
    """``str.replace`` with literal and regex patterns and a replacement
    count."""
    string_dtype = ArrowDtype(pa.string())
    ser = pd.Series(["abac", None], dtype=string_dtype)
    replaced = ser.str.replace(pat, repl, n=n, regex=regex)
    tm.assert_series_equal(replaced, pd.Series(exp, dtype=string_dtype))
|
||
|
||
|
||
def test_str_repeat_unsupported():
    """``str.repeat`` with a list of repeats raises NotImplementedError."""
    string_dtype = ArrowDtype(pa.string())
    ser = pd.Series(["abc", None], dtype=string_dtype)
    with pytest.raises(NotImplementedError, match="repeat is not"):
        ser.str.repeat([1, 2])
|
||
|
||
|
||
@pytest.mark.xfail(
    pa_version_under7p0,
    reason="Unsupported for pyarrow < 7",
    raises=NotImplementedError,
)
def test_str_repeat():
    """``str.repeat`` with a scalar count repeats each string, keeping nulls."""
    string_dtype = ArrowDtype(pa.string())
    ser = pd.Series(["abc", None], dtype=string_dtype)
    repeated = ser.str.repeat(2)
    tm.assert_series_equal(repeated, pd.Series(["abcabc", None], dtype=string_dtype))
|
||
|
||
|
||
@pytest.mark.parametrize(
    "pat, case, na, exp",
    [
        ("ab", False, None, [True, None]),
        ("Ab", True, None, [False, None]),
        ("bc", True, None, [False, None]),
        ("ab", False, True, [True, True]),
        ("a[a-z]{1}", False, None, [True, None]),
        ("A[a-z]{1}", True, None, [False, None]),
    ],
)
def test_str_match(pat, case, na, exp):
    """``str.match`` anchors the pattern at the start of each string."""
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    matched = ser.str.match(pat, case=case, na=na)
    tm.assert_series_equal(matched, pd.Series(exp, dtype=ArrowDtype(pa.bool_())))
|
||
|
||
|
||
@pytest.mark.parametrize(
    "pat, case, na, exp",
    [
        ["abc", False, None, [True, None]],
        ["Abc", True, None, [False, None]],
        ["bc", True, None, [False, None]],
        # "ab" is only a prefix of "abc", so a full match fails; the null
        # entry is filled with na=True.
        ["ab", False, True, [False, True]],
        ["a[a-z]{2}", False, None, [True, None]],
        ["A[a-z]{1}", True, None, [False, None]],
    ],
)
def test_str_fullmatch(pat, case, na, exp):
    """str.fullmatch on Arrow strings requires the whole string to match.

    The accessor under test is ``fullmatch`` (not ``match``), so patterns that
    only match a prefix of the value must produce False.
    """
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.fullmatch(pat, case=case, na=na)
    expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
    tm.assert_series_equal(result, expected)
|
||
|
||
|
||
@pytest.mark.parametrize(
    "sub, start, end, exp, exp_typ",
    [
        ("ab", 0, None, [0, None], pa.int32()),
        # NOTE(review): Python's "abc".find("bc", 1, 3) returns 1; the expected
        # 2 here presumably mirrors the current Arrow implementation - confirm.
        ("bc", 1, 3, [2, None], pa.int64()),
    ],
)
def test_str_find(sub, start, end, exp, exp_typ):
    """str.find on Arrow strings, with and without explicit start/end bounds."""
    strings = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    positions = strings.str.find(sub, start=start, end=end)
    want = pd.Series(exp, dtype=ArrowDtype(exp_typ))
    tm.assert_series_equal(positions, want)
|
||
|
||
|
||
def test_str_find_notimplemented():
    """str.find with only a non-zero start raises NotImplementedError."""
    strings = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    expected_msg = "find not implemented"
    with pytest.raises(NotImplementedError, match=expected_msg):
        strings.str.find("ab", start=1)
|
||
|
||
|
||
@pytest.mark.parametrize(
    "i, exp",
    [
        (1, ["b", "e", None]),
        (-1, ["c", "e", None]),
        (2, ["c", None, None]),
        (-3, ["a", None, None]),
        (4, [None, None, None]),
    ],
)
def test_str_get(i, exp):
    """str.get picks one character per string; out-of-range positions give null."""
    string_dtype = ArrowDtype(pa.string())
    picked = pd.Series(["abc", "de", None], dtype=string_dtype).str.get(i)
    tm.assert_series_equal(picked, pd.Series(exp, dtype=string_dtype))
|
||
|
||
|
||
@pytest.mark.xfail(
    reason="TODO: StringMethods._validate should support Arrow list types",
    raises=AttributeError,
)
def test_str_join():
    """str.join over an Arrow list-of-strings column (currently expected to fail)."""
    list_values = pa.array([list("abc"), list("123"), None])
    lists = pd.Series(ArrowExtensionArray(list_values))
    joined = lists.str.join("=")
    want = pd.Series(["a=b=c", "1=2=3", None], dtype=ArrowDtype(pa.string()))
    tm.assert_series_equal(joined, want)
|
||
|
||
|
||
@pytest.mark.parametrize(
    "start, stop, step, exp",
    [
        (None, 2, None, ["ab", None]),
        (None, 2, 1, ["ab", None]),
        (1, 3, 1, ["bc", None]),
    ],
)
def test_str_slice(start, stop, step, exp):
    """str.slice on Arrow strings with optional start and step."""
    string_dtype = ArrowDtype(pa.string())
    sliced = pd.Series(["abcd", None], dtype=string_dtype).str.slice(
        start, stop, step
    )
    tm.assert_series_equal(sliced, pd.Series(exp, dtype=string_dtype))
|
||
|
||
|
||
@pytest.mark.parametrize(
    "start, stop, repl, exp",
    [
        (1, 2, "x", ["axcd", None]),
        (None, 2, "x", ["xcd", None]),
        (None, 2, None, ["cd", None]),
    ],
)
def test_str_slice_replace(start, stop, repl, exp):
    """str.slice_replace on Arrow strings, including a None replacement."""
    string_dtype = ArrowDtype(pa.string())
    source = pd.Series(["abcd", None], dtype=string_dtype)
    replaced = source.str.slice_replace(start, stop, repl)
    tm.assert_series_equal(replaced, pd.Series(exp, dtype=string_dtype))
|
||
|
||
|
||
@pytest.mark.parametrize(
    "value, method, exp",
    [
        ("a1c", "isalnum", True),
        ("!|,", "isalnum", False),
        ("aaa", "isalpha", True),
        ("!!!", "isalpha", False),
        ("٠", "isdecimal", True),
        ("~!", "isdecimal", False),
        ("2", "isdigit", True),
        ("~", "isdigit", False),
        ("aaa", "islower", True),
        ("aaA", "islower", False),
        ("123", "isnumeric", True),
        ("11I", "isnumeric", False),
        (" ", "isspace", True),
        ("", "isspace", False),
        ("The That", "istitle", True),
        ("the That", "istitle", False),
        ("AAA", "isupper", True),
        ("AAc", "isupper", False),
    ],
)
def test_str_is_functions(value, method, exp):
    """Each str.is* predicate returns an Arrow bool and propagates nulls."""
    strings = pd.Series([value, None], dtype=ArrowDtype(pa.string()))
    predicate = getattr(strings.str, method)
    flags = predicate()
    want = pd.Series([exp, None], dtype=ArrowDtype(pa.bool_()))
    tm.assert_series_equal(flags, want)
|
||
|
||
|
||
@pytest.mark.parametrize(
    "method, exp",
    [
        ("capitalize", "Abc def"),
        ("title", "Abc Def"),
        ("swapcase", "AbC Def"),
        ("lower", "abc def"),
        ("upper", "ABC DEF"),
        ("casefold", "abc def"),
    ],
)
def test_str_transform_functions(method, exp):
    """Case-transforming str methods return Arrow strings and keep nulls."""
    string_dtype = ArrowDtype(pa.string())
    source = pd.Series(["aBc dEF", None], dtype=string_dtype)
    transform = getattr(source.str, method)
    tm.assert_series_equal(transform(), pd.Series([exp, None], dtype=string_dtype))
|
||
|
||
|
||
def test_str_len():
    """str.len returns an Arrow int32 series with nulls preserved."""
    strings = pd.Series(["abcd", None], dtype=ArrowDtype(pa.string()))
    lengths = strings.str.len()
    want = pd.Series([4, None], dtype=ArrowDtype(pa.int32()))
    tm.assert_series_equal(lengths, want)
|
||
|
||
|
||
@pytest.mark.parametrize(
    "method, to_strip, val",
    [
        ("strip", None, " abc "),
        ("strip", "x", "xabcx"),
        ("lstrip", None, " abc"),
        ("lstrip", "x", "xabc"),
        ("rstrip", None, "abc "),
        ("rstrip", "x", "abcx"),
    ],
)
def test_str_strip(method, to_strip, val):
    """strip / lstrip / rstrip remove whitespace or the given characters."""
    string_dtype = ArrowDtype(pa.string())
    padded = pd.Series([val, None], dtype=string_dtype)
    stripper = getattr(padded.str, method)
    stripped = stripper(to_strip=to_strip)
    tm.assert_series_equal(stripped, pd.Series(["abc", None], dtype=string_dtype))
|
||
|
||
|
||
@pytest.mark.parametrize("val", ["abc123", "abc"])
def test_str_removesuffix(val):
    """str.removesuffix drops the suffix when present and is a no-op otherwise."""
    string_dtype = ArrowDtype(pa.string())
    trimmed = pd.Series([val, None], dtype=string_dtype).str.removesuffix("123")
    tm.assert_series_equal(trimmed, pd.Series(["abc", None], dtype=string_dtype))
|
||
|
||
|
||
@pytest.mark.parametrize("val", ["123abc", "abc"])
def test_str_removeprefix(val):
    """str.removeprefix drops the prefix when present and is a no-op otherwise."""
    string_dtype = ArrowDtype(pa.string())
    trimmed = pd.Series([val, None], dtype=string_dtype).str.removeprefix("123")
    tm.assert_series_equal(trimmed, pd.Series(["abc", None], dtype=string_dtype))
|
||
|
||
|
||
@pytest.mark.parametrize("errors", ["ignore", "strict"])
@pytest.mark.parametrize(
    "encoding, exp",
    [
        ("utf8", b"abc"),
        ("utf32", b"\xff\xfe\x00\x00a\x00\x00\x00b\x00\x00\x00c\x00\x00\x00"),
    ],
)
def test_str_encode(errors, encoding, exp):
    """str.encode returns Arrow binary values for the given encoding."""
    strings = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    encoded = strings.str.encode(encoding, errors)
    want = pd.Series([exp, None], dtype=ArrowDtype(pa.binary()))
    tm.assert_series_equal(encoded, want)
|
||
|
||
|
||
@pytest.mark.parametrize("flags", [0, 1])
def test_str_findall(flags):
    """str.findall returns an Arrow list<string> column of matches."""
    strings = pd.Series(["abc", "efg", None], dtype=ArrowDtype(pa.string()))
    matches = strings.str.findall("b", flags=flags)
    list_of_strings = ArrowDtype(pa.list_(pa.string()))
    want = pd.Series([["b"], [], None], dtype=list_of_strings)
    tm.assert_series_equal(matches, want)
|
||
|
||
|
||
@pytest.mark.parametrize("method", ["index", "rindex"])
@pytest.mark.parametrize(
    "start, end",
    [
        (0, None),
        (1, 4),
    ],
)
def test_str_r_index(method, start, end):
    """str.index / str.rindex locate a substring and raise when it is absent."""
    strings = pd.Series(["abcba", None], dtype=ArrowDtype(pa.string()))

    # Substring present: position is returned as Arrow int64.
    located = getattr(strings.str, method)("c", start, end)
    want = pd.Series([2, None], dtype=ArrowDtype(pa.int64()))
    tm.assert_series_equal(located, want)

    # Substring absent: ValueError is raised.
    with pytest.raises(ValueError, match="substring not found"):
        getattr(strings.str, method)("foo", start, end)
|
||
|
||
|
||
@pytest.mark.parametrize("form", ["NFC", "NFKC"])
def test_str_normalize(form):
    """str.normalize leaves these ASCII values unchanged."""
    strings = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    normalized = strings.str.normalize(form)
    tm.assert_series_equal(normalized, strings.copy())
|
||
|
||
|
||
@pytest.mark.parametrize(
    "start, end",
    [
        (0, None),
        (1, 4),
    ],
)
def test_str_rfind(start, end):
    """str.rfind returns the position, -1 when not found, and null for null."""
    strings = pd.Series(["abcba", "foo", None], dtype=ArrowDtype(pa.string()))
    positions = strings.str.rfind("c", start, end)
    want = pd.Series([2, -1, None], dtype=ArrowDtype(pa.int64()))
    tm.assert_series_equal(positions, want)
|
||
|
||
|
||
def test_str_translate():
    """str.translate applies a code-point mapping (97, i.e. "a", becomes "b")."""
    string_dtype = ArrowDtype(pa.string())
    mapping = {97: "b"}
    translated = pd.Series(["abcba", None], dtype=string_dtype).str.translate(
        mapping
    )
    tm.assert_series_equal(
        translated, pd.Series(["bbcbb", None], dtype=string_dtype)
    )
|
||
|
||
|
||
def test_str_wrap():
    """str.wrap inserts a newline at the requested width."""
    string_dtype = ArrowDtype(pa.string())
    wrapped = pd.Series(["abcba", None], dtype=string_dtype).str.wrap(3)
    tm.assert_series_equal(
        wrapped, pd.Series(["abc\nba", None], dtype=string_dtype)
    )
|
||
|
||
|
||
def test_get_dummies():
    """str.get_dummies returns an Arrow bool frame with one column per token."""
    strings = pd.Series(["a|b", None, "a|c"], dtype=ArrowDtype(pa.string()))
    dummies = strings.str.get_dummies()
    indicator_rows = [
        [True, True, False],
        [False, False, False],
        [True, False, True],
    ]
    want = pd.DataFrame(
        indicator_rows,
        dtype=ArrowDtype(pa.bool_()),
        columns=["a", "b", "c"],
    )
    tm.assert_frame_equal(dummies, want)
|
||
|
||
|
||
def test_str_partition():
    """str.partition / str.rpartition, both expanded and as list columns."""
    string_dtype = ArrowDtype(pa.string())
    strings = pd.Series(["abcba", None], dtype=string_dtype)
    null_row = [None, None, None]

    # partition, expanded into a frame
    frame = strings.str.partition("b")
    want_frame = pd.DataFrame([["a", "b", "cba"], null_row], dtype=string_dtype)
    tm.assert_frame_equal(frame, want_frame)

    # partition, kept as a list column
    lists = strings.str.partition("b", expand=False)
    want_lists = pd.Series(
        ArrowExtensionArray(pa.array([["a", "b", "cba"], None]))
    )
    tm.assert_series_equal(lists, want_lists)

    # rpartition, expanded into a frame
    frame = strings.str.rpartition("b")
    want_frame = pd.DataFrame([["abc", "b", "a"], null_row], dtype=string_dtype)
    tm.assert_frame_equal(frame, want_frame)

    # rpartition, kept as a list column
    lists = strings.str.rpartition("b", expand=False)
    want_lists = pd.Series(
        ArrowExtensionArray(pa.array([["abc", "b", "a"], None]))
    )
    tm.assert_series_equal(lists, want_lists)
|
||
|
||
|
||
def test_str_split():
    """str.split: plain, limited by n, regex, and regex with expand=True."""
    # GH 52401
    strings = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string()))

    def as_list_series(rows):
        # Wrap nested Python lists as an Arrow-backed Series.
        return pd.Series(ArrowExtensionArray(pa.array(rows)))

    pieces = strings.str.split("c")
    tm.assert_series_equal(
        pieces, as_list_series([["a1", "b", "b"], ["a2", "b", "b"], None])
    )

    pieces = strings.str.split("c", n=1)
    tm.assert_series_equal(
        pieces, as_list_series([["a1", "bcb"], ["a2", "bcb"], None])
    )

    pieces = strings.str.split("[1-2]", regex=True)
    tm.assert_series_equal(
        pieces, as_list_series([["a", "cbcb"], ["a", "cbcb"], None])
    )

    expanded = strings.str.split("[1-2]", regex=True, expand=True)
    want_frame = pd.DataFrame(
        {
            0: ArrowExtensionArray(pa.array(["a", "a", None])),
            1: ArrowExtensionArray(pa.array(["cbcb", "cbcb", None])),
        }
    )
    tm.assert_frame_equal(expanded, want_frame)
|
||
|
||
|
||
def test_str_rsplit():
    """str.rsplit: plain, limited by n, and limited by n with expand=True."""
    # GH 52401
    strings = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string()))

    def as_list_series(rows):
        # Wrap nested Python lists as an Arrow-backed Series.
        return pd.Series(ArrowExtensionArray(pa.array(rows)))

    pieces = strings.str.rsplit("c")
    tm.assert_series_equal(
        pieces, as_list_series([["a1", "b", "b"], ["a2", "b", "b"], None])
    )

    pieces = strings.str.rsplit("c", n=1)
    tm.assert_series_equal(
        pieces, as_list_series([["a1cb", "b"], ["a2cb", "b"], None])
    )

    expanded = strings.str.rsplit("c", n=1, expand=True)
    want_frame = pd.DataFrame(
        {
            0: ArrowExtensionArray(pa.array(["a1cb", "a2cb", None])),
            1: ArrowExtensionArray(pa.array(["b", "b", None])),
        }
    )
    tm.assert_frame_equal(expanded, want_frame)
|
||
|
||
|
||
def test_str_unsupported_extract():
    """str.extract raises NotImplementedError for Arrow string data."""
    strings = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    expected_msg = "str.extract not supported with pd.ArrowDtype"
    with pytest.raises(NotImplementedError, match=expected_msg):
        strings.str.extract(r"[ab](\d)")
|
||
|
||
|
||
@pytest.mark.parametrize("unit", ["ns", "us", "ms", "s"])
def test_duration_from_strings_with_nat(unit):
    """Parsing duration strings maps "NaT" to null for every unit."""
    # GH51175
    duration_type = pa.duration(unit)
    parsed = ArrowExtensionArray._from_sequence_of_strings(
        ["1000", "NaT"], dtype=duration_type
    )
    want = ArrowExtensionArray(pa.array([1000, None], type=duration_type))
    tm.assert_extension_array_equal(parsed, want)
|
||
|
||
|
||
def test_unsupported_dt(data):
    """The .dt accessor is rejected for non-temporal pyarrow dtypes."""
    if pa.types.is_temporal(data.dtype.pyarrow_dtype):
        # Nothing is asserted for temporal dtypes.
        return
    expected_msg = "Can only use .dt accessor with datetimelike values"
    with pytest.raises(AttributeError, match=expected_msg):
        pd.Series(data).dt
|
||
|
||
|
||
@pytest.mark.parametrize(
    "prop, expected",
    [
        ("year", 2023),
        ("day", 2),
        ("day_of_week", 0),
        ("dayofweek", 0),
        ("weekday", 0),
        ("day_of_year", 2),
        ("dayofyear", 2),
        ("hour", 3),
        ("minute", 4),
        pytest.param(
            "is_leap_year",
            False,
            marks=pytest.mark.xfail(
                pa_version_under8p0,
                raises=NotImplementedError,
                reason="is_leap_year not implemented for pyarrow < 8.0",
            ),
        ),
        ("microsecond", 5),
        ("month", 1),
        ("nanosecond", 6),
        ("quarter", 1),
        ("second", 7),
        ("date", date(2023, 1, 2)),
        ("time", time(3, 4, 7, 5)),
    ],
)
def test_dt_properties(prop, expected):
    """Each .dt property on an Arrow timestamp series yields the expected value."""
    stamp = pd.Timestamp(
        year=2023,
        month=1,
        day=2,
        hour=3,
        minute=4,
        second=7,
        microsecond=5,
        nanosecond=6,
    )
    ser = pd.Series([stamp, None], dtype=ArrowDtype(pa.timestamp("ns")))
    result = getattr(ser.dt, prop)

    # date / time values need an explicit Arrow type; everything else is
    # left for pyarrow to infer.
    if isinstance(expected, date):
        arrow_type = pa.date32()
    elif isinstance(expected, time):
        arrow_type = pa.time64("ns")
    else:
        arrow_type = None
    want = pd.Series(ArrowExtensionArray(pa.array([expected, None], type=arrow_type)))
    tm.assert_series_equal(result, want)
|
||
|
||
|
||
@pytest.mark.parametrize("unit", ["us", "ns"])
def test_dt_time_preserve_unit(unit):
    """.dt.time returns time64 with the same unit as the source timestamp."""
    stamps = [datetime(year=2023, month=1, day=2, hour=3), None]
    ser = pd.Series(stamps, dtype=ArrowDtype(pa.timestamp(unit)))
    times = ser.dt.time
    want_values = pa.array([time(3, 0), None], type=pa.time64(unit))
    tm.assert_series_equal(times, pd.Series(ArrowExtensionArray(want_values)))
|
||
|
||
|
||
@pytest.mark.parametrize("tz", [None, "UTC", "US/Pacific"])
def test_dt_tz(tz):
    """.dt.tz compares equal to the timezone the Arrow dtype was built with."""
    stamps = [datetime(year=2023, month=1, day=2, hour=3), None]
    ser = pd.Series(stamps, dtype=ArrowDtype(pa.timestamp("ns", tz=tz)))
    assert ser.dt.tz == tz
|
||
|
||
|
||
def test_dt_isocalendar():
    """.dt.isocalendar returns an int64[pyarrow] frame of year / week / day."""
    stamps = [datetime(year=2023, month=1, day=2, hour=3), None]
    ser = pd.Series(stamps, dtype=ArrowDtype(pa.timestamp("ns")))
    calendar = ser.dt.isocalendar()
    want = pd.DataFrame(
        [[2023, 1, 1], [0, 0, 0]],
        columns=["year", "week", "day"],
        dtype="int64[pyarrow]",
    )
    tm.assert_frame_equal(calendar, want)
|
||
|
||
|
||
def test_dt_strftime(request):
    """.dt.strftime formats Arrow timestamps into Arrow strings."""
    if is_platform_windows() and is_ci_environment():
        tz_database_xfail = pytest.mark.xfail(
            raises=pa.ArrowInvalid,
            reason=(
                "TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
                "on CI to path to the tzdata for pyarrow."
            ),
        )
        request.node.add_marker(tz_database_xfail)
    stamps = [datetime(year=2023, month=1, day=2, hour=3), None]
    ser = pd.Series(stamps, dtype=ArrowDtype(pa.timestamp("ns")))
    formatted = ser.dt.strftime("%Y-%m-%dT%H:%M:%S")
    want = pd.Series(
        ["2023-01-02T03:00:00.000000000", None], dtype=ArrowDtype(pa.string())
    )
    tm.assert_series_equal(formatted, want)
|
||
|
||
|
||
@pytest.mark.parametrize("method", ["ceil", "floor", "round"])
def test_dt_roundlike_tz_options_not_supported(method):
    """ceil / floor / round reject the ambiguous and nonexistent options."""
    stamps = [datetime(year=2023, month=1, day=2, hour=3), None]
    ser = pd.Series(stamps, dtype=ArrowDtype(pa.timestamp("ns")))

    with pytest.raises(NotImplementedError, match="ambiguous is not supported."):
        getattr(ser.dt, method)("1H", ambiguous="NaT")

    with pytest.raises(NotImplementedError, match="nonexistent is not supported."):
        getattr(ser.dt, method)("1H", nonexistent="NaT")
|
||
|
||
|
||
@pytest.mark.parametrize("method", ["ceil", "floor", "round"])
def test_dt_roundlike_unsupported_freq(method):
    """ceil / floor / round raise ValueError for freq "1B" and for freq None."""
    stamps = [datetime(year=2023, month=1, day=2, hour=3), None]
    ser = pd.Series(stamps, dtype=ArrowDtype(pa.timestamp("ns")))

    with pytest.raises(ValueError, match="freq='1B' is not supported"):
        getattr(ser.dt, method)("1B")

    with pytest.raises(ValueError, match="Must specify a valid frequency: None"):
        getattr(ser.dt, method)(None)
|
||
|
||
|
||
@pytest.mark.xfail(
    pa_version_under7p0, reason="Methods not supported for pyarrow < 7.0"
)
@pytest.mark.parametrize("freq", ["D", "H", "T", "S", "L", "U", "N"])
@pytest.mark.parametrize("method", ["ceil", "floor", "round"])
def test_dt_ceil_year_floor(freq, method):
    """Arrow-backed ceil / floor / round agree with the default-dtype result."""
    arrow_dtype = ArrowDtype(pa.timestamp("ns"))
    rule = f"1{freq}"
    plain = pd.Series([datetime(year=2023, month=1, day=1), None])
    # Reference: round on the default dtype, then cast to the Arrow dtype.
    want = getattr(plain.dt, method)(rule).astype(arrow_dtype)
    got = getattr(plain.astype(arrow_dtype).dt, method)(rule)
    tm.assert_series_equal(got, want)
|
||
|
||
|
||
def test_dt_to_pydatetime():
    """.dt.to_pydatetime returns an object array of plain datetime instances."""
    # GH 51859
    stamps = [datetime(2022, 1, 1), datetime(2023, 1, 1)]
    ser = pd.Series(stamps, dtype=ArrowDtype(pa.timestamp("ns")))

    converted = ser.dt.to_pydatetime()
    tm.assert_numpy_array_equal(converted, np.array(stamps, dtype=object))
    # Exact type check: subclasses of datetime are not accepted.
    for item in converted:
        assert type(item) is datetime

    # Must also agree with the datetime64[ns] code path.
    via_numpy = ser.astype("datetime64[ns]").dt.to_pydatetime()
    tm.assert_numpy_array_equal(converted, via_numpy)
|
||
|
||
|
||
@pytest.mark.parametrize("date_type", [32, 64])
def test_dt_to_pydatetime_date_error(date_type):
    """.dt.to_pydatetime raises ValueError for date32 / date64 data."""
    # GH 52812
    date_factory = getattr(pa, f"date{date_type}")
    ser = pd.Series([date(2022, 12, 31)], dtype=ArrowDtype(date_factory()))
    with pytest.raises(ValueError, match="to_pydatetime cannot be called with"):
        ser.dt.to_pydatetime()
|
||
|
||
|
||
def test_dt_tz_localize_unsupported_tz_options():
    """tz_localize rejects ambiguous='NaT' and nonexistent='NaT'."""
    stamps = [datetime(year=2023, month=1, day=2, hour=3), None]
    ser = pd.Series(stamps, dtype=ArrowDtype(pa.timestamp("ns")))

    with pytest.raises(NotImplementedError, match="ambiguous='NaT' is not supported"):
        ser.dt.tz_localize("UTC", ambiguous="NaT")

    with pytest.raises(NotImplementedError, match="nonexistent='NaT' is not supported"):
        ser.dt.tz_localize("UTC", nonexistent="NaT")
|
||
|
||
|
||
def test_dt_tz_localize_none():
    """tz_localize(None) turns a tz-aware Arrow timestamp into a naive one."""
    stamps = [datetime(year=2023, month=1, day=2, hour=3), None]
    aware = pd.Series(stamps, dtype=ArrowDtype(pa.timestamp("ns", tz="US/Pacific")))
    localized = aware.dt.tz_localize(None)
    naive = pd.Series(stamps, dtype=ArrowDtype(pa.timestamp("ns")))
    tm.assert_series_equal(localized, naive)
|
||
|
||
|
||
@pytest.mark.parametrize("unit", ["us", "ns"])
def test_dt_tz_localize(unit, request):
    """tz_localize matches pyarrow's assume_timezone for the same data."""
    if is_platform_windows() and is_ci_environment():
        tz_database_xfail = pytest.mark.xfail(
            raises=pa.ArrowInvalid,
            reason=(
                "TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
                "on CI to path to the tzdata for pyarrow."
            ),
        )
        request.node.add_marker(tz_database_xfail)
    stamps = [datetime(year=2023, month=1, day=2, hour=3), None]
    ser = pd.Series(stamps, dtype=ArrowDtype(pa.timestamp(unit)))
    localized = ser.dt.tz_localize("US/Pacific")

    # Build the reference directly with pyarrow.
    naive_values = pa.array(stamps, type=pa.timestamp(unit))
    aware_values = pa.compute.assume_timezone(naive_values, "US/Pacific")
    tm.assert_series_equal(localized, pd.Series(ArrowExtensionArray(aware_values)))
|
||
|
||
|
||
@pytest.mark.parametrize(
    "nonexistent, exp_date",
    [
        ("shift_forward", datetime(year=2023, month=3, day=12, hour=3)),
        ("shift_backward", pd.Timestamp("2023-03-12 01:59:59.999999999")),
    ],
)
def test_dt_tz_localize_nonexistent(nonexistent, exp_date, request):
    """tz_localize shifts a nonexistent wall time per the `nonexistent` option."""
    if is_platform_windows() and is_ci_environment():
        tz_database_xfail = pytest.mark.xfail(
            raises=pa.ArrowInvalid,
            reason=(
                "TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
                "on CI to path to the tzdata for pyarrow."
            ),
        )
        request.node.add_marker(tz_database_xfail)
    timestamp_type = pa.timestamp("ns")
    ser = pd.Series(
        [datetime(year=2023, month=3, day=12, hour=2, minute=30), None],
        dtype=ArrowDtype(timestamp_type),
    )
    localized = ser.dt.tz_localize("US/Pacific", nonexistent=nonexistent)

    # Reference: localize the already-shifted wall time with pyarrow.
    shifted_values = pa.array([exp_date, None], type=pa.timestamp("ns"))
    aware_values = pa.compute.assume_timezone(shifted_values, "US/Pacific")
    tm.assert_series_equal(localized, pd.Series(ArrowExtensionArray(aware_values)))
|
||
|
||
|
||
@pytest.mark.parametrize("skipna", [True, False])
def test_boolean_reduce_series_all_null(all_boolean_reductions, skipna):
    """Boolean reductions over an all-null Arrow series."""
    # GH51624
    ser = pd.Series([None], dtype="float64[pyarrow]")
    reducer = getattr(ser, all_boolean_reductions)
    result = reducer(skipna=skipna)
    # With skipna the expectation is True only for "all"; without it, pd.NA.
    expected = (all_boolean_reductions == "all") if skipna else pd.NA
    assert result is expected
|
||
|
||
|
||
@pytest.mark.parametrize("dtype", ["string", "string[pyarrow]"])
def test_series_from_string_array(dtype):
    """A Series built from a raw pyarrow array equals one built from the wrapper."""
    words = pa.array("the quick brown fox".split())
    from_raw = pd.Series(words, dtype=dtype)
    from_wrapped = pd.Series(ArrowExtensionArray(words), dtype=dtype)
    tm.assert_series_equal(from_raw, from_wrapped)
|
||
|
||
|
||
def test_setitem_boolean_replace_with_mask_segfault():
    """Setting with an all-False mask on a large bool array leaves it unchanged."""
    # GH#52059
    size = 145_000
    all_true = np.ones((size,), dtype=np.bool_)
    arr = ArrowExtensionArray(pa.chunked_array([all_true]))
    snapshot = arr.copy()
    nothing_selected = np.zeros((size,), dtype=np.bool_)
    arr[nothing_selected] = False
    assert arr._data == snapshot._data
|
||
|
||
|
||
@pytest.mark.parametrize(
    "data, arrow_dtype",
    [
        ([b"a", b"b"], pa.large_binary()),
        (["a", "b"], pa.large_string()),
    ],
)
def test_conversion_large_dtypes_from_numpy_array(data, arrow_dtype):
    """pd.array gives the same result from a numpy array as from a list."""
    large_dtype = ArrowDtype(arrow_dtype)
    from_numpy = pd.array(np.array(data), dtype=large_dtype)
    from_list = pd.array(data, dtype=large_dtype)
    tm.assert_extension_array_equal(from_numpy, from_list)
|
||
|
||
|
||
@pytest.mark.parametrize("pa_type", tm.ALL_INT_PYARROW_DTYPES + tm.FLOAT_PYARROW_DTYPES)
def test_describe_numeric_data(pa_type):
    """describe() on Arrow numeric data returns float64[pyarrow] statistics."""
    # GH 52470
    values = pd.Series([1, 2, 3], dtype=ArrowDtype(pa_type))
    summary = values.describe()
    stat_names = ["count", "mean", "std", "min", "25%", "50%", "75%", "max"]
    want = pd.Series(
        [3, 2, 1, 1, 1.5, 2.0, 2.5, 3],
        dtype=ArrowDtype(pa.float64()),
        index=stat_names,
    )
    tm.assert_series_equal(summary, want)
|
||
|
||
|
||
@pytest.mark.parametrize("pa_type", tm.TIMEDELTA_PYARROW_DTYPES)
def test_describe_timedelta_data(pa_type):
    """describe() on Arrow duration data returns a count plus timedelta stats."""
    # GH53001
    durations = pd.Series(range(1, 10), dtype=ArrowDtype(pa_type))
    summary = durations.describe()
    stat_values = pd.to_timedelta([5, 2, 1, 3, 5, 7, 9], unit=pa_type.unit).tolist()
    want = pd.Series(
        [9] + stat_values,
        dtype=object,
        index=["count", "mean", "std", "min", "25%", "50%", "75%", "max"],
    )
    tm.assert_series_equal(summary, want)
|
||
|
||
|
||
@pytest.mark.parametrize("pa_type", tm.DATETIME_PYARROW_DTYPES)
def test_describe_datetime_data(pa_type):
    """describe() on Arrow timestamp data returns a count plus Timestamp stats."""
    # GH53001
    stamps = pd.Series(range(1, 10), dtype=ArrowDtype(pa_type))
    summary = stamps.describe()
    stat_values = [
        pd.Timestamp(v, tz=pa_type.tz, unit=pa_type.unit) for v in [5, 1, 3, 5, 7, 9]
    ]
    want = pd.Series(
        [9] + stat_values,
        dtype=object,
        index=["count", "mean", "min", "25%", "50%", "75%", "max"],
    )
    tm.assert_series_equal(summary, want)
|
||
|
||
|
||
@pytest.mark.xfail(
    pa_version_under8p0,
    reason="Function 'add_checked' has no kernel matching input types",
    raises=pa.ArrowNotImplementedError,
)
def test_duration_overflow_from_ndarray_containing_nat():
    """Adding Arrow timestamp and duration series that each contain a null."""
    # GH52843
    timestamps = pd.Series(
        pd.to_datetime([1, None]), dtype=ArrowDtype(pa.timestamp("ns"))
    )
    durations = pd.Series(
        pd.to_timedelta([1, None]), dtype=ArrowDtype(pa.duration("ns"))
    )
    total = timestamps + durations
    want = pd.Series([2, None], dtype=ArrowDtype(pa.timestamp("ns")))
    tm.assert_series_equal(total, want)
|