"""This test for the LFW require medium-size data downloading and processing
|
|
|
|
If the data has not been already downloaded by running the examples,
|
|
the tests won't run (skipped).
|
|
|
|
If the test are run, the first execution will be long (typically a bit
|
|
more than a couple of minutes) but as the dataset loader is leveraging
|
|
joblib, successive runs will be fast (less than 200ms).
|
|
"""
|
|
|
|
import random
|
|
from functools import partial
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from sklearn.datasets import fetch_lfw_pairs, fetch_lfw_people
|
|
from sklearn.datasets.tests.test_common import check_return_X_y
|
|
from sklearn.utils._testing import assert_array_equal
|
|
|
|
FAKE_NAMES = [
|
|
"Abdelatif_Smith",
|
|
"Abhati_Kepler",
|
|
"Camara_Alvaro",
|
|
"Chen_Dupont",
|
|
"John_Lee",
|
|
"Lin_Bauman",
|
|
"Onur_Lopez",
|
|
]
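
# For reference, these tests fake the on-disk layout that a real fetch would
# produce. Outside the test suite, the download mentioned in the module
# docstring would be triggered by a call such as the sketch below (not
# executed here; the parameter value is illustrative only):
#
#   lfw_people = fetch_lfw_people(min_faces_per_person=70)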


@pytest.fixture(scope="module")
def mock_empty_data_home(tmp_path_factory):
    data_dir = tmp_path_factory.mktemp("scikit_learn_empty_test")

    yield data_dir


@pytest.fixture(scope="module")
def mock_data_home(tmp_path_factory):
    """Test fixture run once and common to all tests of this module"""
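    # The fixture fakes a pre-extracted LFW archive; the paths created below
    # are the ones the loaders are expected to read:
    #
    #   <data_dir>/lfw_home/lfw_funneled/<Person_Name>/<Person_Name>_0000.jpg
    #   <data_dir>/lfw_home/pairsDevTrain.txt
    #   <data_dir>/lfw_home/pairsDevTest.txt
    #   <data_dir>/lfw_home/pairs.txt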
    Image = pytest.importorskip("PIL.Image")

    data_dir = tmp_path_factory.mktemp("scikit_learn_lfw_test")
    lfw_home = data_dir / "lfw_home"
    lfw_home.mkdir(parents=True, exist_ok=True)

    random_state = random.Random(42)
    np_rng = np.random.RandomState(42)

    # generate some random jpeg files for each person
    counts = {}
    for name in FAKE_NAMES:
        folder_name = lfw_home / "lfw_funneled" / name
        folder_name.mkdir(parents=True, exist_ok=True)

        n_faces = np_rng.randint(1, 5)
        counts[name] = n_faces
        for i in range(n_faces):
            file_path = folder_name / (name + "_%04d.jpg" % i)
            uniface = np_rng.randint(0, 255, size=(250, 250, 3))
            img = Image.fromarray(uniface.astype(np.uint8))
            img.save(file_path)

    # add some random file pollution to test robustness
    (lfw_home / "lfw_funneled" / ".test.swp").write_bytes(
        b"Text file to be ignored by the dataset loader."
    )

    # generate some pairing metadata files using the same format as LFW
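    # The file starts with a header line (here "10", the total number of pairs
    # written below); matched pairs follow as "name\t<i>\t<j>" and mismatched
    # pairs as "name1\t<i>\tname2\t<j>", mirroring the real LFW pairs files.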
    with open(lfw_home / "pairsDevTrain.txt", "wb") as f:
        f.write(b"10\n")
        more_than_two = [name for name, count in counts.items() if count >= 2]
        for i in range(5):
            name = random_state.choice(more_than_two)
            first, second = random_state.sample(range(counts[name]), 2)
            f.write(("%s\t%d\t%d\n" % (name, first, second)).encode())

        for i in range(5):
            first_name, second_name = random_state.sample(FAKE_NAMES, 2)
            first_index = np_rng.choice(np.arange(counts[first_name]))
            second_index = np_rng.choice(np.arange(counts[second_name]))
            f.write(
                (
                    "%s\t%d\t%s\t%d\n"
                    % (first_name, first_index, second_name, second_index)
                ).encode()
            )

    (lfw_home / "pairsDevTest.txt").write_bytes(
        b"Fake place holder that won't be tested"
    )
    (lfw_home / "pairs.txt").write_bytes(b"Fake place holder that won't be tested")

    yield data_dir


def test_load_empty_lfw_people(mock_empty_data_home):
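    # With download_if_missing=False and an empty data_home, the loader is
    # expected to raise instead of attempting a download.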
    with pytest.raises(OSError):
        fetch_lfw_people(data_home=mock_empty_data_home, download_if_missing=False)


def test_load_fake_lfw_people(mock_data_home):
    lfw_people = fetch_lfw_people(
        data_home=mock_data_home, min_faces_per_person=3, download_if_missing=False
    )

    # The data is cropped around the center as a rectangular bounding box
    # around the face. Colors are converted to gray levels:
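    # With the default slice_ and resize, the loaded faces are 62x47
    # gray-level images, and `data` holds the flattened pixels (62 * 47 == 2914).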
    assert lfw_people.images.shape == (10, 62, 47)
    assert lfw_people.data.shape == (10, 2914)

    # the target is an array of person integer ids
    assert_array_equal(lfw_people.target, [2, 0, 1, 0, 2, 0, 2, 1, 1, 2])

    # names of the persons can be found using the target_names array
    expected_classes = ["Abdelatif Smith", "Abhati Kepler", "Onur Lopez"]
    assert_array_equal(lfw_people.target_names, expected_classes)

    # It is possible to ask for the original data without any cropping or color
    # conversion and with no limit on the number of pictures per person
    lfw_people = fetch_lfw_people(
        data_home=mock_data_home,
        resize=None,
        slice_=None,
        color=True,
        download_if_missing=False,
    )
    assert lfw_people.images.shape == (17, 250, 250, 3)
    assert lfw_people.DESCR.startswith(".. _labeled_faces_in_the_wild_dataset:")

    # the ids and class names are the same as previously
    assert_array_equal(
        lfw_people.target, [0, 0, 1, 6, 5, 6, 3, 6, 0, 3, 6, 1, 2, 4, 5, 1, 2]
    )
    assert_array_equal(
        lfw_people.target_names,
        [
            "Abdelatif Smith",
            "Abhati Kepler",
            "Camara Alvaro",
            "Chen Dupont",
            "John Lee",
            "Lin Bauman",
            "Onur Lopez",
        ],
    )

    # test return_X_y option
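    # check_return_X_y re-fetches the data with return_X_y=True and checks
    # that the resulting (data, target) tuple matches the Bunch attributes
    # (shared helper from sklearn.datasets.tests.test_common).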
    fetch_func = partial(
        fetch_lfw_people,
        data_home=mock_data_home,
        resize=None,
        slice_=None,
        color=True,
        download_if_missing=False,
    )
    check_return_X_y(lfw_people, fetch_func)


def test_load_fake_lfw_people_too_restrictive(mock_data_home):
    with pytest.raises(ValueError):
        fetch_lfw_people(
            data_home=mock_data_home,
            min_faces_per_person=100,
            download_if_missing=False,
        )


def test_load_empty_lfw_pairs(mock_empty_data_home):
    with pytest.raises(OSError):
        fetch_lfw_pairs(data_home=mock_empty_data_home, download_if_missing=False)


def test_load_fake_lfw_pairs(mock_data_home):
    lfw_pairs_train = fetch_lfw_pairs(
        data_home=mock_data_home, download_if_missing=False
    )

    # The data is cropped around the center as a rectangular bounding box
    # around the face. Colors are converted to gray levels:
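    # `pairs` stacks the two faces of each pair: (n_pairs, 2, height, width).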
    assert lfw_pairs_train.pairs.shape == (10, 2, 62, 47)

    # the target is whether the person is the same or not
    assert_array_equal(lfw_pairs_train.target, [1, 1, 1, 1, 1, 0, 0, 0, 0, 0])

    # names of the persons can be found using the target_names array
    expected_classes = ["Different persons", "Same person"]
    assert_array_equal(lfw_pairs_train.target_names, expected_classes)

    # It is possible to ask for the original data without any cropping or color
    # conversion
    lfw_pairs_train = fetch_lfw_pairs(
        data_home=mock_data_home,
        resize=None,
        slice_=None,
        color=True,
        download_if_missing=False,
    )
    assert lfw_pairs_train.pairs.shape == (10, 2, 250, 250, 3)

    # the ids and class names are the same as previously
    assert_array_equal(lfw_pairs_train.target, [1, 1, 1, 1, 1, 0, 0, 0, 0, 0])
    assert_array_equal(lfw_pairs_train.target_names, expected_classes)

    assert lfw_pairs_train.DESCR.startswith(".. _labeled_faces_in_the_wild_dataset:")


def test_fetch_lfw_people_internal_cropping(mock_data_home):
    """Check that we properly crop the images.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/24942
    """
    # If cropping was not done properly and we don't resize the images, the
    # images would keep their original size (250x250) and would not fit in the
    # NumPy array pre-allocated based on the `slice_` parameter.
    slice_ = (slice(70, 195), slice(78, 172))
    lfw = fetch_lfw_people(
        data_home=mock_data_home,
        min_faces_per_person=3,
        download_if_missing=False,
        resize=None,
        slice_=slice_,
    )
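    # With this slice_ and no resizing, the expected image shape is
    # (195 - 70, 172 - 78) == (125, 94).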
    assert lfw.images[0].shape == (
        slice_[0].stop - slice_[0].start,
        slice_[1].stop - slice_[1].start,
    )
|