# projektAI/venv/Lib/site-packages/sklearn/neighbors/tests/test_dist_metrics.py
#
# 228 lines, 7.2 KiB, Python (web viewer: Raw / Normal View / History)
#
# Snapshot: 2021-06-06 22:13:05 +02:00
import itertools
import pickle
import numpy as np
from numpy.testing import assert_array_almost_equal
import pytest
from scipy.spatial.distance import cdist
from sklearn.neighbors import DistanceMetric
from sklearn.neighbors import BallTree
from sklearn.utils import check_random_state
from sklearn.utils._testing import assert_raises_regex
from sklearn.utils.fixes import sp_version, parse_version
def dist_func(x1, x2, p):
    """Reference Minkowski distance, used below as a user-supplied metric.

    Parameters
    ----------
    x1, x2 : ndarray
        The two vectors to compare.
    p : float
        Order of the Minkowski norm.

    Returns
    -------
    float
        ``sum(|x1 - x2| ** p) ** (1 / p)``.
    """
    # The absolute value is required. Without it, odd or fractional ``p``
    # raises negative differences to a power and gives a wrong (or NaN)
    # result. For even ``p`` (e.g. the p=2 used in the tests) nothing changes.
    return np.sum(np.abs(x1 - x2) ** p) ** (1. / p)
# Shared test data. Every random value below comes from this single seeded
# RandomState, so the order of the ``rng`` calls (X1, X2, V, then the
# seuclidean ``V`` and wminkowski ``w`` inside METRICS_DEFAULT_PARAMS)
# determines the data; do not reorder these statements.
rng = check_random_state(0)
d = 4  # number of features per sample
n1 = 20  # number of rows in X1
n2 = 25  # number of rows in X2
X1 = rng.random_sample((n1, d)).astype('float64', copy=False)
X2 = rng.random_sample((n2, d)).astype('float64', copy=False)
# make boolean arrays: ones and zeros
# (values drawn from [0, 1) rounded to 0.0 / 1.0; the dtype stays float)
X1_bool = X1.round(0)
X2_bool = X2.round(0)
V = rng.random_sample((d, d))
# V @ V.T is symmetric positive semi-definite; it is passed as the
# inverse-covariance parameter ``VI`` of the mahalanobis metric below.
VI = np.dot(V, V.T)
# Metrics tested only on the 0/1-valued X1_bool / X2_bool arrays.
BOOL_METRICS = ['matching', 'jaccard', 'dice',
                'kulsinski', 'rogerstanimoto', 'russellrao',
                'sokalmichener', 'sokalsneath']
# Maps metric name -> {parameter name: tuple of candidate values}. The tests
# iterate over the Cartesian product of these tuples (itertools.product), so
# an empty dict means "test once with no extra parameters".
METRICS_DEFAULT_PARAMS = {'euclidean': {},
                          'cityblock': {},
                          'minkowski': dict(p=(1, 1.5, 2, 3)),
                          'chebyshev': {},
                          'seuclidean': dict(V=(rng.random_sample(d),)),
                          'wminkowski': dict(p=(1, 1.5, 3),
                                             w=(rng.random_sample(d),)),
                          'mahalanobis': dict(VI=(VI,)),
                          'hamming': {},
                          'canberra': {},
                          'braycurtis': {}}
@pytest.mark.parametrize('metric', METRICS_DEFAULT_PARAMS)
def test_cdist(metric):
    """Compare ``DistanceMetric.pairwise(X1, X2)`` with SciPy's ``cdist``.

    Every combination of the metric's parameter values listed in
    ``METRICS_DEFAULT_PARAMS`` is checked.
    """
    argdict = METRICS_DEFAULT_PARAMS[metric]
    keys = argdict.keys()
    for vals in itertools.product(*argdict.values()):
        kwargs = dict(zip(keys, vals))
        if metric == "wminkowski":
            if sp_version >= parse_version("1.8.0"):
                pytest.skip("wminkowski will be removed in SciPy 1.8.0")
            # wminkowski is deprecated in SciPy 1.6.0 and removed in 1.8.0.
            # Only assert the DeprecationWarning where SciPy emits it.
            # ``pytest.warns(None)`` is deprecated in pytest 7 and an error
            # in pytest 8, so older SciPy versions call cdist directly.
            if sp_version >= parse_version("1.6.0"):
                with pytest.warns(DeprecationWarning):
                    D_true = cdist(X1, X2, metric, **kwargs)
            else:
                D_true = cdist(X1, X2, metric, **kwargs)
        else:
            D_true = cdist(X1, X2, metric, **kwargs)
        check_cdist(metric, kwargs, D_true)
@pytest.mark.parametrize('metric', BOOL_METRICS)
def test_cdist_bool_metric(metric):
    """Boolean metrics: pairwise(X1_bool, X2_bool) must agree with SciPy."""
    expected = cdist(X1_bool, X2_bool, metric)
    check_cdist_bool(metric, expected)
def check_cdist(metric, kwargs, D_true):
    """Assert that ``metric`` built with ``kwargs`` reproduces ``D_true``.

    The distances are computed between the rows of X1 and the rows of X2.
    """
    assert_array_almost_equal(
        DistanceMetric.get_metric(metric, **kwargs).pairwise(X1, X2),
        D_true,
    )
def check_cdist_bool(metric, D_true):
    """Assert the boolean ``metric`` reproduces ``D_true`` on (X1_bool, X2_bool)."""
    assert_array_almost_equal(
        DistanceMetric.get_metric(metric).pairwise(X1_bool, X2_bool),
        D_true,
    )
@pytest.mark.parametrize('metric', METRICS_DEFAULT_PARAMS)
def test_pdist(metric):
    """Compare ``DistanceMetric.pairwise(X1)`` with SciPy's ``cdist(X1, X1)``.

    Every combination of the metric's parameter values listed in
    ``METRICS_DEFAULT_PARAMS`` is checked.
    """
    argdict = METRICS_DEFAULT_PARAMS[metric]
    keys = argdict.keys()
    for vals in itertools.product(*argdict.values()):
        kwargs = dict(zip(keys, vals))
        if metric == "wminkowski":
            if sp_version >= parse_version("1.8.0"):
                pytest.skip("wminkowski will be removed in SciPy 1.8.0")
            # wminkowski is deprecated in SciPy 1.6.0 and removed in 1.8.0.
            # Only assert the DeprecationWarning where SciPy emits it.
            # ``pytest.warns(None)`` is deprecated in pytest 7 and an error
            # in pytest 8, so older SciPy versions call cdist directly.
            if sp_version >= parse_version("1.6.0"):
                with pytest.warns(DeprecationWarning):
                    D_true = cdist(X1, X1, metric, **kwargs)
            else:
                D_true = cdist(X1, X1, metric, **kwargs)
        else:
            D_true = cdist(X1, X1, metric, **kwargs)
        check_pdist(metric, kwargs, D_true)
@pytest.mark.parametrize('metric', BOOL_METRICS)
def test_pdist_bool_metrics(metric):
    """Boolean metrics: pairwise(X1_bool) must agree with SciPy's cdist."""
    expected = cdist(X1_bool, X1_bool, metric)
    check_pdist_bool(metric, expected)
def check_pdist(metric, kwargs, D_true):
    """Assert that ``metric`` built with ``kwargs`` reproduces ``D_true``.

    The distances are computed between all pairs of rows of X1.
    """
    assert_array_almost_equal(
        DistanceMetric.get_metric(metric, **kwargs).pairwise(X1),
        D_true,
    )
def check_pdist_bool(metric, D_true):
    """Assert the boolean ``metric`` reproduces ``D_true`` on X1_bool.

    Note: ``D_true`` may be modified in place (see the jaccard note below).
    """
    computed = DistanceMetric.get_metric(metric).pairwise(X1_bool)
    # See https://github.com/scipy/scipy/pull/7373: for two all-zero vectors
    # the jaccard metric returns 0 from scipy 1.2.0 onwards, not nan.
    # Normalise the expectation from older scipy versions to match.
    legacy_jaccard = metric == 'jaccard' and sp_version < parse_version('1.2.0')
    if legacy_jaccard:
        D_true[np.isnan(D_true)] = 0
    assert_array_almost_equal(computed, D_true)
@pytest.mark.parametrize('metric', METRICS_DEFAULT_PARAMS)
def test_pickle(metric):
    """Every parameterisation of ``metric`` must survive a pickle round-trip."""
    param_grid = METRICS_DEFAULT_PARAMS[metric]
    names = list(param_grid)
    for combo in itertools.product(*param_grid.values()):
        check_pickle(metric, dict(zip(names, combo)))
@pytest.mark.parametrize('metric', BOOL_METRICS)
def test_pickle_bool_metrics(metric):
    """Boolean metrics must give identical distances after unpickling."""
    original = DistanceMetric.get_metric(metric)
    before = original.pairwise(X1_bool)
    restored = pickle.loads(pickle.dumps(original))
    after = restored.pairwise(X1_bool)
    assert_array_almost_equal(before, after)
def check_pickle(metric, kwargs):
    """Assert a pickled/unpickled ``metric`` gives the same distances on X1."""
    original = DistanceMetric.get_metric(metric, **kwargs)
    before = original.pairwise(X1)
    restored = pickle.loads(pickle.dumps(original))
    after = restored.pairwise(X1)
    assert_array_almost_equal(before, after)
def test_haversine_metric():
    """Check the ``haversine`` metric against a direct per-pair formula.

    Also checks that ``dist_to_rdist`` maps a haversine distance ``D`` to
    ``sin(D / 2) ** 2``.
    """
    def haversine_slow(x1, x2):
        # Haversine formula with x[0] as latitude and x[1] as longitude
        # (in radians), evaluated one pair at a time.
        return 2 * np.arcsin(np.sqrt(np.sin(0.5 * (x1[0] - x2[0])) ** 2
                                     + np.cos(x1[0]) * np.cos(x2[0]) *
                                     np.sin(0.5 * (x1[1] - x2[1])) ** 2))

    # Use a seeded local generator instead of the unseeded global
    # ``np.random`` state, so that any failure can be reproduced.
    X = check_random_state(0).random_sample((10, 2))
    haversine = DistanceMetric.get_metric("haversine")

    D1 = haversine.pairwise(X)
    D2 = np.zeros_like(D1)
    for i, x1 in enumerate(X):
        for j, x2 in enumerate(X):
            D2[i, j] = haversine_slow(x1, x2)

    assert_array_almost_equal(D1, D2)
    assert_array_almost_equal(haversine.dist_to_rdist(D1),
                              np.sin(0.5 * D2) ** 2)
def test_pyfunc_metric():
    """A Python-callable metric must match the built-in one and be picklable.

    ``dist_func`` with ``p=2`` computes the Euclidean distance. Its pairwise
    matrix is therefore compared with the built-in ``"euclidean"`` metric,
    both as constructed and after a pickle round-trip.
    """
    # Use a seeded local generator instead of the unseeded global
    # ``np.random`` state, so that any failure can be reproduced.
    X = check_random_state(0).random_sample((10, 3))

    euclidean = DistanceMetric.get_metric("euclidean")
    pyfunc = DistanceMetric.get_metric("pyfunc", func=dist_func, p=2)

    # Check that DistanceMetric objects initialised from both a callable
    # metric and a predefined metric are picklable.
    euclidean_pkl = pickle.loads(pickle.dumps(euclidean))
    pyfunc_pkl = pickle.loads(pickle.dumps(pyfunc))

    D1 = euclidean.pairwise(X)
    D2 = pyfunc.pairwise(X)

    D1_pkl = euclidean_pkl.pairwise(X)
    D2_pkl = pyfunc_pkl.pairwise(X)

    assert_array_almost_equal(D1, D2)
    assert_array_almost_equal(D1_pkl, D2_pkl)
def test_bad_pyfunc_metric():
    """A callable metric returning a non-numeric value must be rejected.

    Building a BallTree with a metric that returns the string ``"1"`` should
    raise ``TypeError`` with an explanatory message.
    """
    def wrong_distance(x, y):
        return "1"

    X = np.ones((5, 2))
    # ``pytest.raises(match=...)`` uses ``re.search``, the same matching
    # rule as the legacy ``assert_raises_regex`` helper it replaces.
    with pytest.raises(TypeError,
                       match="Custom distance function must accept two vectors"):
        BallTree(X, metric=wrong_distance)
def test_input_data_size():
    """Regression test for #6288: a pyfunc metric must receive full rows.

    Previously, a metric requiring a particular input dimension would fail.
    ``custom_metric`` asserts it is given 3-element vectors and returns the
    squared Euclidean distance. Its pairwise matrix must therefore equal the
    built-in Euclidean matrix squared.
    """
    def custom_metric(x, y):
        assert x.shape[0] == 3
        diff = x - y
        return np.sum(diff ** 2)

    X = check_random_state(0).rand(10, 3)

    squared_metric = DistanceMetric.get_metric("pyfunc", func=custom_metric)
    euclid_metric = DistanceMetric.get_metric("euclidean")
    expected = euclid_metric.pairwise(X) ** 2
    assert_array_almost_equal(squared_metric.pairwise(X), expected)