124 lines
3.7 KiB
Python
124 lines
3.7 KiB
Python
|
# Authors: Christian Lorentzen <lorentzen.ch@gmail.com>
|
||
|
#
|
||
|
# License: BSD 3 clause
|
||
|
#
|
||
|
# TODO(1.3): remove file
|
||
|
import numpy as np
|
||
|
from numpy.testing import (
|
||
|
assert_allclose,
|
||
|
assert_array_equal,
|
||
|
)
|
||
|
from scipy.optimize import check_grad
|
||
|
import pytest
|
||
|
|
||
|
from sklearn._loss.glm_distribution import (
|
||
|
TweedieDistribution,
|
||
|
NormalDistribution,
|
||
|
PoissonDistribution,
|
||
|
GammaDistribution,
|
||
|
InverseGaussianDistribution,
|
||
|
DistributionBoundary,
|
||
|
)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "family, expected",
    [
        (NormalDistribution(), [True, True, True]),
        (PoissonDistribution(), [False, True, True]),
        (TweedieDistribution(power=1.5), [False, True, True]),
        (GammaDistribution(), [False, False, True]),
        (InverseGaussianDistribution(), [False, False, True]),
        (TweedieDistribution(power=4.5), [False, False, True]),
    ],
)
def test_family_bounds(family, expected):
    """Check `in_y_range` of each family at the points -1, 0 and 1.

    `expected` holds, per probe point, whether the family reports the
    point as being inside its valid range.
    """
    probe_points = [-1, 0, 1]
    in_range = family.in_y_range(probe_points)
    assert_array_equal(in_range, expected)
|
||
|
|
||
|
|
||
|
def test_invalid_distribution_bound():
    """Check that a non-DistributionBoundary lower bound is rejected.

    The private `_lower_bound` attribute is overwritten with a plain
    integer; `in_y_range` is then expected to raise a TypeError.
    """
    tweedie = TweedieDistribution()
    tweedie._lower_bound = 0

    expected_msg = "must be of type DistributionBoundary"
    with pytest.raises(TypeError, match=expected_msg):
        tweedie.in_y_range([-1, 0, 1])
|
||
|
|
||
|
|
||
|
def test_tweedie_distribution_power():
    """Check validation of `power` and the lower bound derived from it.

    Covers three things:

    - constructing with `power=0.5` raises a ValueError;
    - a complex `power` raises a TypeError, both when passed to the
      constructor and when assigned through the `power` attribute;
    - a default instance has a non-inclusive `DistributionBoundary` as
      lower bound, and setting `power = 1` makes it an inclusive bound at
      0.0.
    """
    msg = "distribution is only defined for power<=0 and power>=1"
    with pytest.raises(ValueError, match=msg):
        TweedieDistribution(power=0.5)

    with pytest.raises(TypeError, match="must be a real number"):
        TweedieDistribution(power=1j)

    # Build the instance outside of the `raises` block so that only the
    # assignment to `power` can satisfy the expected TypeError; an error
    # coming from the constructor must not make this check pass.
    dist = TweedieDistribution()
    with pytest.raises(TypeError, match="must be a real number"):
        dist.power = 1j

    # Start again from a fresh instance, independent of the failed
    # assignment above.
    dist = TweedieDistribution()
    assert isinstance(dist._lower_bound, DistributionBoundary)

    assert dist._lower_bound.inclusive is False
    dist.power = 1
    assert dist._lower_bound.value == 0.0
    assert dist._lower_bound.inclusive is True
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "family, chk_values",
    [
        (NormalDistribution(), [-1.5, -0.1, 0.1, 2.5]),
        (PoissonDistribution(), [0.1, 1.5]),
        (GammaDistribution(), [0.1, 1.5]),
        (InverseGaussianDistribution(), [0.1, 1.5]),
        (TweedieDistribution(power=-2.5), [0.1, 1.5]),
        (TweedieDistribution(power=-1), [0.1, 1.5]),
        (TweedieDistribution(power=1.5), [0.1, 1.5]),
        (TweedieDistribution(power=2.5), [0.1, 1.5]),
        (TweedieDistribution(power=-4), [0.1, 1.5]),
    ],
)
def test_deviance_zero(family, chk_values):
    """Check that deviance(y, y) is zero for every family.

    Each value in `chk_values` is used as both arguments of
    `family.deviance`; the result must be 0 within an absolute tolerance
    of 1e-9.
    """
    for value in chk_values:
        self_deviance = family.deviance(value, value)
        assert_allclose(self_deviance, 0, atol=1e-9)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "family",
    [
        NormalDistribution(),
        PoissonDistribution(),
        GammaDistribution(),
        InverseGaussianDistribution(),
        TweedieDistribution(power=-2.5),
        TweedieDistribution(power=-1),
        TweedieDistribution(power=1.5),
        TweedieDistribution(power=2.5),
        TweedieDistribution(power=-4),
    ],
    ids=lambda fam: type(fam).__name__,
)
def test_deviance_derivative(family, global_random_seed):
    """Compare `deviance_derivative` against a numerical gradient.

    The deviance must be a Python float and its derivative must have the
    shape of `y_pred`. The discrepancy reported by
    `scipy.optimize.check_grad`, divided by the norm of the analytic
    derivative, must stay below 3e-6.
    """
    rng = np.random.RandomState(global_random_seed)

    # Shift the uniform draws by their minimum plus a small offset so that
    # every target is positive.
    y_true = rng.rand(10)
    y_true += np.abs(y_true.min()) + 1e-2

    # Predictions are the targets plus a non-negative random offset.
    offset = np.fmax(rng.rand(10), 0.0)
    y_pred = y_true + offset

    dev = family.deviance(y_true, y_pred)
    assert isinstance(dev, float)
    dev_derivative = family.deviance_derivative(y_true, y_pred)
    assert dev_derivative.shape == y_pred.shape

    def objective(pred):
        return family.deviance(y_true, pred)

    def gradient(pred):
        return family.deviance_derivative(y_true, pred)

    absolute_error = check_grad(objective, gradient, y_pred)
    err = absolute_error / np.linalg.norm(dev_derivative)
    assert abs(err) < 3e-6
|