33 lines
813 B
Python
33 lines
813 B
Python
|
import warnings
|
||
|
|
||
|
import numpy as np
|
||
|
import pytest
|
||
|
|
||
|
from sklearn.utils import Bunch
|
||
|
|
||
|
|
||
|
def test_bunch_attribute_deprecation():
    """Check that item access on a deprecated `Bunch` key raises a warning.

    One array is registered through `Bunch._set_deprecated` with a new key
    (``"grid_values"``), a deprecated key (``"values"``) and a warning
    message. The test then asserts that:

    - reading ``bunch["grid_values"]`` emits no warning;
    - reading ``bunch["values"]`` emits a `FutureWarning` whose text contains
      the configured message;
    - both reads return the very same object that was registered.
    """
    # Local import: only needed to escape the expected warning message below.
    import re

    bunch = Bunch()
    values = np.asarray([1, 2, 3])
    msg = (
        "Key: 'values', is deprecated in 1.3 and will be "
        "removed in 1.5. Please use 'grid_values' instead"
    )
    bunch._set_deprecated(
        values, new_key="grid_values", deprecated_key="values", warning_message=msg
    )

    with warnings.catch_warnings():
        # Promote every warning to an exception so that an unexpected warning
        # on the non-deprecated key fails the test.
        warnings.simplefilter("error")
        v = bunch["grid_values"]

    assert v is values

    # `match` is interpreted as a regular expression; escape the message so
    # that the "." in the version numbers is matched literally.
    with pytest.warns(FutureWarning, match=re.escape(msg)):
        v = bunch["values"]

    assert v is values
|