3RNN/Lib/site-packages/sklearn/neighbors/tests/test_kd_tree.py
2024-05-26 19:49:15 +02:00

101 lines
3.8 KiB
Python

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_equal
from sklearn.neighbors._kd_tree import KDTree, KDTree32, KDTree64
from sklearn.neighbors.tests.test_ball_tree import get_dataset_for_binary_tree
from sklearn.utils.parallel import Parallel, delayed
DIMENSION = 3
METRICS = {"euclidean": {}, "manhattan": {}, "chebyshev": {}, "minkowski": dict(p=3)}
KD_TREE_CLASSES = [
KDTree64,
KDTree32,
]
def test_KDTree_is_KDTree64_subclass():
assert issubclass(KDTree, KDTree64)
@pytest.mark.parametrize("BinarySearchTree", KD_TREE_CLASSES)
def test_array_object_type(BinarySearchTree):
"""Check that we do not accept object dtype array."""
X = np.array([(1, 2, 3), (2, 5), (5, 5, 1, 2)], dtype=object)
with pytest.raises(ValueError, match="setting an array element with a sequence"):
BinarySearchTree(X)
@pytest.mark.parametrize("BinarySearchTree", KD_TREE_CLASSES)
def test_kdtree_picklable_with_joblib(BinarySearchTree):
"""Make sure that KDTree queries work when joblib memmaps.
Non-regression test for #21685 and #21228."""
rng = np.random.RandomState(0)
X = rng.random_sample((10, 3))
tree = BinarySearchTree(X, leaf_size=2)
# Call Parallel with max_nbytes=1 to trigger readonly memory mapping that
# use to raise "ValueError: buffer source array is read-only" in a previous
# version of the Cython code.
Parallel(n_jobs=2, max_nbytes=1)(delayed(tree.query)(data) for data in 2 * [X])
@pytest.mark.parametrize("metric", METRICS)
def test_kd_tree_numerical_consistency(global_random_seed, metric):
# Results on float64 and float32 versions of a dataset must be
# numerically close.
X_64, X_32, Y_64, Y_32 = get_dataset_for_binary_tree(
random_seed=global_random_seed, features=50
)
metric_params = METRICS.get(metric, {})
kd_64 = KDTree64(X_64, leaf_size=2, metric=metric, **metric_params)
kd_32 = KDTree32(X_32, leaf_size=2, metric=metric, **metric_params)
# Test consistency with respect to the `query` method
k = 4
dist_64, ind_64 = kd_64.query(Y_64, k=k)
dist_32, ind_32 = kd_32.query(Y_32, k=k)
assert_allclose(dist_64, dist_32, rtol=1e-5)
assert_equal(ind_64, ind_32)
assert dist_64.dtype == np.float64
assert dist_32.dtype == np.float32
# Test consistency with respect to the `query_radius` method
r = 2.38
ind_64 = kd_64.query_radius(Y_64, r=r)
ind_32 = kd_32.query_radius(Y_32, r=r)
for _ind64, _ind32 in zip(ind_64, ind_32):
assert_equal(_ind64, _ind32)
# Test consistency with respect to the `query_radius` method
# with return distances being true
ind_64, dist_64 = kd_64.query_radius(Y_64, r=r, return_distance=True)
ind_32, dist_32 = kd_32.query_radius(Y_32, r=r, return_distance=True)
for _ind64, _ind32, _dist_64, _dist_32 in zip(ind_64, ind_32, dist_64, dist_32):
assert_equal(_ind64, _ind32)
assert_allclose(_dist_64, _dist_32, rtol=1e-5)
assert _dist_64.dtype == np.float64
assert _dist_32.dtype == np.float32
@pytest.mark.parametrize("metric", METRICS)
def test_kernel_density_numerical_consistency(global_random_seed, metric):
# Test consistency with respect to the `kernel_density` method
X_64, X_32, Y_64, Y_32 = get_dataset_for_binary_tree(random_seed=global_random_seed)
metric_params = METRICS.get(metric, {})
kd_64 = KDTree64(X_64, leaf_size=2, metric=metric, **metric_params)
kd_32 = KDTree32(X_32, leaf_size=2, metric=metric, **metric_params)
kernel = "gaussian"
h = 0.1
density64 = kd_64.kernel_density(Y_64, h=h, kernel=kernel, breadth_first=True)
density32 = kd_32.kernel_density(Y_32, h=h, kernel=kernel, breadth_first=True)
assert_allclose(density64, density32, rtol=1e-5)
assert density64.dtype == np.float64
assert density32.dtype == np.float32