projektAI/venv/Lib/site-packages/mlxtend/evaluate/permutation.py
2021-06-06 22:13:05 +02:00

141 lines
4.9 KiB
Python

# Sebastian Raschka 2014-2020
# mlxtend Machine Learning Library Extensions
#
# Nonparametric Permutation Test
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause
import numpy as np
from itertools import combinations
from math import factorial
try:
    from nose.tools import nottest
except ImportError:
    # nose is not available: provide a stand-in decorator.
    def nottest(f):
        """Fallback for ``nose.tools.nottest``: mark *f* as not-a-test.

        Mirrors nose's own implementation by setting ``f.__test__ = False``
        (the attribute test collectors check) and returning *f* unchanged.
        """
        f.__test__ = False
        return f


# decorator to prevent nose from considering
# this as a unit test due to "test" in the name
@nottest
def permutation_test(x, y, func='x_mean != y_mean', method='exact',
                     num_rounds=1000, seed=None):
    """
    Nonparametric permutation test

    Parameters
    -------------
    x : list or numpy array with shape (n_datapoints,)
        A list or 1D numpy array of the first sample
        (e.g., the treatment group).

    y : list or numpy array with shape (n_datapoints,)
        A list or 1D numpy array of the second sample
        (e.g., the control group).

    func : custom function or str (default: 'x_mean != y_mean')
        function to compute the statistic for the permutation test.
        A custom function is called as ``func(x_sample, y_sample)`` with
        1D numpy arrays (for the observed split as well as for every
        permuted split) and should return a scalar; splits whose
        statistic is >= the observed one (or numerically close to it)
        count as "at least as extreme".
        - If 'x_mean != y_mean', uses
          `func=lambda x, y: np.abs(np.mean(x) - np.mean(y))`
          for a two-sided test.
        - If 'x_mean > y_mean', uses
          `func=lambda x, y: np.mean(x) - np.mean(y)`
          for a one-sided test.
        - If 'x_mean < y_mean', uses
          `func=lambda x, y: np.mean(y) - np.mean(x)`
          for a one-sided test.

    method : 'approximate' or 'exact' (default: 'exact')
        If 'exact' (default), all possible splits of the pooled data into
        groups of sizes len(x) and len(y) are enumerated (the observed
        split is one of them).
        If 'approximate' the number of drawn samples is
        given by `num_rounds`, and the observed split is added as one
        extra sample so the p-value is never exactly zero.
        Note that 'exact' is typically not feasible unless the dataset
        size is relatively small.

    num_rounds : int (default: 1000)
        The number of permutation samples if `method='approximate'`.

    seed : int or None (default: None)
        The random seed for generating permutation samples if
        `method='approximate'`.

    Returns
    ----------
    p : float
        p-value under the null hypothesis

    Raises
    ----------
    AttributeError
        If `method` or the string form of `func` is not recognized.

    Examples
    -----------
    For usage examples, please see
    http://rasbt.github.io/mlxtend/user_guide/evaluate/permutation_test/
    """
    if method not in ('approximate', 'exact'):
        raise AttributeError('method must be "approximate"'
                             ' or "exact", got %s' % method)

    if isinstance(func, str):
        if func not in (
                'x_mean != y_mean', 'x_mean > y_mean', 'x_mean < y_mean'):
            raise AttributeError('Provide a custom function'
                                 ' lambda x,y: ... or a string'
                                 ' in ("x_mean != y_mean", '
                                 '"x_mean > y_mean", "x_mean < y_mean")')
        elif func == 'x_mean != y_mean':
            def func(x, y):
                return np.abs(np.mean(x) - np.mean(y))
        elif func == 'x_mean > y_mean':
            def func(x, y):
                return np.mean(x) - np.mean(y)
        else:
            def func(x, y):
                return np.mean(y) - np.mean(x)

    rng = np.random.RandomState(seed)

    m, n = len(x), len(y)
    combined = np.hstack((x, y))

    at_least_as_extreme = 0.
    # Evaluate the observed statistic on the same kind of ndarray slices
    # the permuted samples are passed as, so custom functions relying on
    # ndarray methods also work when x / y are plain lists. This happens
    # before any in-place shuffling of `combined`.
    reference_stat = func(combined[:m], combined[m:])

    # Whether we enumerate combinations or permutations does not affect
    # the result: every combination of n_A specific objects in A and n_B
    # specific objects in B corresponds to exactly n_A! x n_B! permutations,
    # all of which merely re-arrange the elements within each group and
    # hence yield the same statistic. That factor appears in both the
    # numerator and the denominator of the p-value, so using combinations
    # instead of permutations simply saves computational time.
    if method == 'exact':
        all_positions = range(m + n)
        for indices_x in combinations(all_positions, m):
            # Set membership avoids an O(m) tuple scan per position;
            # indices_y stays in ascending order.
            chosen = set(indices_x)
            indices_y = [i for i in all_positions if i not in chosen]
            diff = func(combined[list(indices_x)], combined[indices_y])

            if diff > reference_stat or np.isclose(diff, reference_stat):
                at_least_as_extreme += 1.

        # The enumeration already contains the observed split (the first
        # combination, positions 0..m-1), so no +1 correction is applied.
        # Integer division keeps the count exact; true division of large
        # factorials could overflow a float.
        num_rounds = factorial(m + n) // (factorial(m) * factorial(n))

    else:
        for _ in range(num_rounds):
            rng.shuffle(combined)
            diff = func(combined[:m], combined[m:])

            if diff > reference_stat or np.isclose(diff, reference_stat):
                at_least_as_extreme += 1.

        # To cover the actual experiment results: the observed split is
        # counted as one additional resample (it is trivially "at least as
        # extreme" as itself), so the Monte Carlo p-value is never zero.
        at_least_as_extreme += 1
        num_rounds += 1

    return at_least_as_extreme / num_rounds