Inzynierka/Lib/site-packages/sklearn/datasets/tests/test_rcv1.py

69 lines
2.3 KiB
Python
Raw Permalink Normal View History

2023-06-02 12:51:02 +02:00
"""Test the rcv1 loader, if the data is available,
or if specifically requested via environment variable
(e.g. for CI jobs)."""
import scipy.sparse as sp
import numpy as np
from functools import partial
from sklearn.datasets.tests.test_common import check_return_X_y
from sklearn.utils._testing import assert_almost_equal
from sklearn.utils._testing import assert_array_equal
def test_fetch_rcv1(fetch_rcv1_fxt):
    """Check the full RCV1 bunch and its shuffled ``"train"`` subset.

    The test loads the dataset twice through the fixture: once unshuffled
    and complete, once shuffled and restricted to ``subset="train"``. It
    then verifies sparsity, sizes, shapes, the description header, category
    ordering, per-category counts, and that individual samples carry the
    same feature/target rows in both loads.

    Parameters
    ----------
    fetch_rcv1_fxt : callable
        Fixture called here with the keyword arguments ``shuffle``,
        ``subset`` and ``random_state``. The returned object is read
        through its ``data``, ``target``, ``target_names``, ``sample_id``
        and ``DESCR`` attributes.
    """
    data1 = fetch_rcv1_fxt(shuffle=False)
    X1, Y1 = data1.data, data1.target
    cat_list, s1 = data1.target_names.tolist(), data1.sample_id

    # test sparsity
    assert sp.issparse(X1)
    assert sp.issparse(Y1)
    assert 60915113 == X1.data.size
    assert 2606875 == Y1.data.size

    # test shapes
    assert (804414, 47236) == X1.shape
    assert (804414, 103) == Y1.shape
    assert (804414,) == s1.shape
    assert 103 == len(cat_list)

    # test descr
    assert data1.DESCR.startswith(".. _rcv1_dataset:")

    # test ordering of categories
    first_categories = ["C11", "C12", "C13", "C14", "C15", "C151"]
    assert_array_equal(first_categories, cat_list[:6])

    # test number of sample for some categories
    some_categories = ("GMIL", "E143", "CCAT")
    number_non_zero_in_cat = (5, 1206, 381327)
    for num, cat in zip(number_non_zero_in_cat, some_categories):
        j = cat_list.index(cat)
        assert num == Y1[:, j].data.size

    # test shuffling and subset
    data2 = fetch_rcv1_fxt(shuffle=True, subset="train", random_state=77)
    X2, Y2 = data2.data, data2.target
    s2 = data2.sample_id

    # test return_X_y option
    # NOTE(review): fetch_func loads with shuffle=False while data2 was
    # loaded with shuffle=True; presumably check_return_X_y only compares
    # order-independent properties -- confirm against its definition.
    fetch_func = partial(fetch_rcv1_fxt, shuffle=False, subset="train")
    check_return_X_y(data2, fetch_func)

    # The first 23149 samples are the training samples
    assert_array_equal(np.sort(s1[:23149]), np.sort(s2))

    # test some precise values
    some_sample_ids = (2286, 3274, 14042)
    # Convert the id arrays to lists once, outside the loop: s1 alone has
    # 804414 entries, so rebuilding the list for every looked-up id is
    # wasted work.
    sample_ids_1 = s1.tolist()
    sample_ids_2 = s2.tolist()
    for sample_id in some_sample_ids:
        # list.index raises ValueError if the id is missing from a load.
        idx1 = sample_ids_1.index(sample_id)
        idx2 = sample_ids_2.index(sample_id)

        feature_values_1 = X1[idx1, :].toarray()
        feature_values_2 = X2[idx2, :].toarray()
        assert_almost_equal(feature_values_1, feature_values_2)

        target_values_1 = Y1[idx1, :].toarray()
        target_values_2 = Y2[idx2, :].toarray()
        assert_almost_equal(target_values_1, target_values_2)