110 lines
2.6 KiB
Python
110 lines
2.6 KiB
Python
|
import contextlib
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
import pandas as pd
|
||
|
import pandas._testing as tm
|
||
|
from pandas.core import accessor
|
||
|
|
||
|
|
||
|
def test_dirname_mixin():
    """Verify what ``dir()`` reports for a ``DirNamesMixin`` subclass.

    Regression test for GH37173: of the public names, only the class
    attribute ``x`` and the instance attribute ``z`` are expected; the bare
    annotation ``y: int`` is never assigned a value and must not be listed.
    """

    class X(accessor.DirNamesMixin):
        x = 1  # class attribute with a value
        y: int  # annotation only, never assigned

        def __init__(self):
            self.z = 3  # instance attribute

    public_names = []
    for name in dir(X()):
        if name.startswith("_"):
            continue
        public_names.append(name)

    assert public_names == ["x", "z"]
|
||
|
|
||
|
|
||
|
@contextlib.contextmanager
def ensure_removed(obj, attr):
    """
    Context manager that cleans up an attribute registered on ``obj``.

    On exit (whether or not the body raised), ``attr`` is deleted from
    ``obj`` if it is present, and its name is dropped from
    ``obj._accessors``.
    """
    try:
        yield
    finally:
        # The body may never have set the attribute; a missing one is fine.
        with contextlib.suppress(AttributeError):
            delattr(obj, attr)
        obj._accessors.discard(attr)
|
||
|
|
||
|
|
||
|
class MyAccessor:
    """Minimal accessor class used by the registration tests in this file.

    It keeps a reference to the object it was constructed with in ``obj``
    and stores the string ``"item"`` in ``item``, which is exposed both
    through the ``prop`` property and the ``method`` method.
    """

    def __init__(self, obj):
        self.obj, self.item = obj, "item"

    def method(self):
        """Return the current value of ``item``."""
        return self.item

    @property
    def prop(self):
        """Return the current value of ``item``."""
        return self.item
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "obj, registrar",
    [
        (pd.Series, pd.api.extensions.register_series_accessor),
        (pd.DataFrame, pd.api.extensions.register_dataframe_accessor),
        (pd.Index, pd.api.extensions.register_index_accessor),
    ],
)
def test_register(obj, registrar):
    """Registering an accessor adds exactly one new name to the class."""
    with ensure_removed(obj, "mine"):
        names_before = set(dir(obj))

        decorate = registrar("mine")
        decorate(MyAccessor)

        # An empty Series is built with an explicit object dtype; the other
        # containers are built from a bare empty list.
        if obj is pd.Series:
            instance = obj([], dtype=object)
        else:
            instance = obj([])
        assert instance.mine.prop == "item"

        names_after = set(dir(obj))
        assert names_before.symmetric_difference(names_after) == {"mine"}
        # Checked inside the ``with`` block: the name is discarded on exit.
        assert "mine" in obj._accessors
|
||
|
|
||
|
|
||
|
def test_accessor_works():
    """A registered Series accessor wraps the Series it is accessed from."""
    with ensure_removed(pd.Series, "mine"):
        register = pd.api.extensions.register_series_accessor("mine")
        register(MyAccessor)

        series = pd.Series([1, 2])
        # The accessor must hold the very same Series object.
        assert series.mine.obj is series

        for value in (series.mine.prop, series.mine.method()):
            assert value == "item"
|
||
|
|
||
|
|
||
|
def test_overwrite_warns():
    """Registering an accessor over an existing attribute emits a UserWarning.

    ``Series.mean`` is deliberately overwritten, so both pieces of
    class-level state touched by the registration are restored afterwards:
    the original ``mean`` attribute and the ``_accessors`` name set.
    """
    # Need to restore mean
    mean = pd.Series.mean
    try:
        with tm.assert_produces_warning(UserWarning) as w:
            pd.api.extensions.register_series_accessor("mean")(MyAccessor)
            s = pd.Series([1, 2])
            assert s.mean.prop == "item"
        msg = str(w[0].message)
        # The warning should name the attribute, the accessor and the target.
        assert "mean" in msg
        assert "MyAccessor" in msg
        assert "Series" in msg
    finally:
        pd.Series.mean = mean
        # Registration also records the name in ``_accessors``; drop it so
        # "mean" is not left behind in that set for later tests.
        pd.Series._accessors.discard("mean")
|
||
|
|
||
|
|
||
|
def test_raises_attribute_error():
    """
    An ``AttributeError`` raised inside an accessor's ``__init__`` must
    reach the code that accessed the attribute, with its message intact.
    """
    with ensure_removed(pd.Series, "bad"):

        class Bad:
            def __init__(self, data):
                raise AttributeError("whoops")

        pd.api.extensions.register_series_accessor("bad")(Bad)

        empty = pd.Series([], dtype=object)
        with pytest.raises(AttributeError, match="whoops"):
            empty.bad
|