285 lines
8.0 KiB
Python
285 lines
8.0 KiB
Python
|
from io import BytesIO
|
||
|
import textwrap
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
from sklearn.datasets._arff_parser import (
|
||
|
_liac_arff_parser,
|
||
|
_pandas_arff_parser,
|
||
|
_post_process_frame,
|
||
|
load_arff_from_gzip_file,
|
||
|
)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "feature_names, target_names",
    [
        (
            [
                "col_int_as_integer",
                "col_int_as_numeric",
                "col_float_as_real",
                "col_float_as_numeric",
            ],
            selected_targets,
        )
        # Same numeric features each time; only the number of targets varies.
        for selected_targets in (
            ["col_categorical", "col_string"],
            ["col_categorical"],
            [],
        )
    ],
)
def test_post_process_frame(feature_names, target_names):
    """Check the container types returned when splitting a dataframe.

    `_post_process_frame` is called with the same frame and feature columns
    and with two, one, and zero target columns. `X` must always be a
    dataframe; `y` must be a dataframe, a series, or `None` respectively.
    """
    pd = pytest.importorskip("pandas")

    frame = pd.DataFrame(
        {
            "col_int_as_integer": [1, 2, 3],
            "col_int_as_numeric": [1, 2, 3],
            "col_float_as_real": [1.0, 2.0, 3.0],
            "col_float_as_numeric": [1.0, 2.0, 3.0],
            "col_categorical": ["a", "b", "c"],
            "col_string": ["a", "b", "c"],
        }
    )

    X, y = _post_process_frame(frame, feature_names, target_names)
    assert isinstance(X, pd.DataFrame)

    # The expected container for `y` depends only on how many targets were
    # requested.
    n_targets = len(target_names)
    if n_targets == 0:
        assert y is None
    elif n_targets == 1:
        assert isinstance(y, pd.Series)
    else:
        assert isinstance(y, pd.DataFrame)
|
||
|
|
||
|
|
||
|
def test_load_arff_from_gzip_file_error_parser():
    """An error will be raised if the parser is not known."""
    import re

    # None of the input parameters are required to be accurate since the check
    # of the parser will be carried out first.

    err_msg = "Unknown parser: 'xxx'. Should be 'liac-arff' or 'pandas'"
    # `match` is interpreted as a regular expression; escape the message so
    # that the "." is matched literally instead of as "any character".
    with pytest.raises(ValueError, match=re.escape(err_msg)):
        load_arff_from_gzip_file("xxx", "xxx", "xxx", "xxx", "xxx", "xxx")
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("parser_func", [_liac_arff_parser, _pandas_arff_parser])
def test_pandas_arff_parser_strip_single_quotes(parser_func):
    """Check how single-quoted values in the data section are parsed.

    The categorical value is expected without its quotes for both parsers.
    For the string columns, the expected values drop the surrounding single
    quotes with the LIAC-ARFF parser and keep them with the pandas parser.
    """
    pd = pytest.importorskip("pandas")

    # `dedent` removes the indentation shared by all lines of the literal, so
    # the parser receives the ARFF directives starting at column 0.
    arff_text = textwrap.dedent(
        """
        @relation 'toy'
        @attribute 'cat_single_quote' {'A', 'B', 'C'}
        @attribute 'str_single_quote' string
        @attribute 'str_nested_quote' string
        @attribute 'class' numeric
        @data
        'A','some text','\"expect double quotes\"',0
        """
    )
    arff_file = BytesIO(arff_text.encode("utf-8"))

    # Column name -> OpenML data type, in the order the columns are declared.
    data_types = {
        "cat_single_quote": "nominal",
        "str_single_quote": "string",
        "str_nested_quote": "string",
        "class": "numeric",
    }
    columns_info = {
        column: {"data_type": data_type, "name": column}
        for column, data_type in data_types.items()
    }

    target_names = ["class"]
    feature_names = [column for column in data_types if column not in target_names]

    # We don't strip single quotes for string columns with the pandas parser.
    if parser_func is _liac_arff_parser:
        expected_str = "some text"
        expected_nested = '"expect double quotes"'
    else:
        expected_str = "'some text'"
        expected_nested = "'\"expect double quotes\"'"

    expected_values = {
        "cat_single_quote": "A",
        "str_single_quote": expected_str,
        "str_nested_quote": expected_nested,
        "class": 0,
    }

    _, _, frame, _ = parser_func(
        arff_file,
        output_arrays_type="pandas",
        openml_columns_info=columns_info,
        feature_names_to_select=feature_names,
        target_names_to_select=target_names,
    )

    assert frame.columns.tolist() == feature_names + target_names
    expected_first_row = pd.Series(expected_values, name=0)
    pd.testing.assert_series_equal(frame.iloc[0], expected_first_row)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("parser_func", [_liac_arff_parser, _pandas_arff_parser])
def test_pandas_arff_parser_strip_double_quotes(parser_func):
    """Check that double quotes surrounding values are stripped from the data.

    Both parsers are expected to return the same row: the outer double quotes
    are removed, while single quotes nested inside a double-quoted string are
    kept.
    """
    pd = pytest.importorskip("pandas")

    # `dedent` removes the indentation shared by all lines of the literal, so
    # the parser receives the ARFF directives starting at column 0.
    arff_text = textwrap.dedent(
        """
        @relation 'toy'
        @attribute 'cat_double_quote' {"A", "B", "C"}
        @attribute 'str_double_quote' string
        @attribute 'str_nested_quote' string
        @attribute 'class' numeric
        @data
        "A","some text","\'expect double quotes\'",0
        """
    )
    arff_file = BytesIO(arff_text.encode("utf-8"))

    # Column name -> OpenML data type, in the order the columns are declared.
    data_types = {
        "cat_double_quote": "nominal",
        "str_double_quote": "string",
        "str_nested_quote": "string",
        "class": "numeric",
    }
    columns_info = {
        column: {"data_type": data_type, "name": column}
        for column, data_type in data_types.items()
    }

    target_names = ["class"]
    feature_names = [column for column in data_types if column not in target_names]

    expected_values = dict(
        zip(
            data_types,
            ["A", "some text", "'expect double quotes'", 0],
        )
    )

    _, _, frame, _ = parser_func(
        arff_file,
        output_arrays_type="pandas",
        openml_columns_info=columns_info,
        feature_names_to_select=feature_names,
        target_names_to_select=target_names,
    )

    assert frame.columns.tolist() == feature_names + target_names
    expected_first_row = pd.Series(expected_values, name=0)
    pd.testing.assert_series_equal(frame.iloc[0], expected_first_row)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "parser_func",
    [
        # internal quotes are not considered to follow the ARFF spec in LIAC ARFF
        pytest.param(_liac_arff_parser, marks=pytest.mark.xfail),
        _pandas_arff_parser,
    ],
)
def test_pandas_arff_parser_strip_no_quotes(parser_func):
    """Check the parsing of values that are not surrounded by quotes.

    The data row contains unquoted values, one of which has single quotes in
    the middle of the string; those internal quotes are expected to be kept
    as-is. The LIAC-ARFF case is marked as an expected failure.
    """
    pd = pytest.importorskip("pandas")

    # `dedent` removes the indentation shared by all lines of the literal, so
    # the parser receives the ARFF directives starting at column 0.
    arff_text = textwrap.dedent(
        """
        @relation 'toy'
        @attribute 'cat_without_quote' {A, B, C}
        @attribute 'str_without_quote' string
        @attribute 'str_internal_quote' string
        @attribute 'class' numeric
        @data
        A,some text,'internal' quote,0
        """
    )
    arff_file = BytesIO(arff_text.encode("utf-8"))

    # Column name -> OpenML data type, in the order the columns are declared.
    data_types = {
        "cat_without_quote": "nominal",
        "str_without_quote": "string",
        "str_internal_quote": "string",
        "class": "numeric",
    }
    columns_info = {
        column: {"data_type": data_type, "name": column}
        for column, data_type in data_types.items()
    }

    target_names = ["class"]
    feature_names = [column for column in data_types if column not in target_names]

    expected_values = dict(
        zip(
            data_types,
            ["A", "some text", "'internal' quote", 0],
        )
    )

    _, _, frame, _ = parser_func(
        arff_file,
        output_arrays_type="pandas",
        openml_columns_info=columns_info,
        feature_names_to_select=feature_names,
        target_names_to_select=target_names,
    )

    assert frame.columns.tolist() == feature_names + target_names
    expected_first_row = pd.Series(expected_values, name=0)
    pd.testing.assert_series_equal(frame.iloc[0], expected_first_row)
|