141 lines
4.9 KiB
Python
141 lines
4.9 KiB
Python
|
# Sebastian Raschka 2014-2020
|
||
|
# mlxtend Machine Learning Library Extensions
|
||
|
#
|
||
|
# Nonparametric Permutation Test
|
||
|
# Author: Sebastian Raschka <sebastianraschka.com>
|
||
|
#
|
||
|
# License: BSD 3 clause
|
||
|
|
||
|
import numpy as np
|
||
|
from itertools import combinations
|
||
|
from math import factorial
|
||
|
try:
    from nose.tools import nottest
except ImportError:
    # Use a no-op decorator if nose is not available
    def nottest(f):
        """Return `f` unchanged; stand-in for `nose.tools.nottest`."""
        return f


# decorator to prevent nose to consider
# this as a unit test due to "test" in the name
@nottest
def permutation_test(x, y, func='x_mean != y_mean', method='exact',
                     num_rounds=1000, seed=None):
    """
    Nonparametric permutation test

    Parameters
    -------------
    x : list or numpy array with shape (n_datapoints,)
        A list or 1D numpy array of the first sample
        (e.g., the treatment group).
    y : list or numpy array with shape (n_datapoints,)
        A list or 1D numpy array of the second sample
        (e.g., the control group).
    func : custom function or str (default: 'x_mean != y_mean')
        function to compute the statistic for the permutation test.
        A custom function is called as `func(x_part, y_part)` and must
        return a scalar; larger values count as "more extreme".
        - If 'x_mean != y_mean', uses
          `func=lambda x, y: np.abs(np.mean(x) - np.mean(y))`
          for a two-sided test.
        - If 'x_mean > y_mean', uses
          `func=lambda x, y: np.mean(x) - np.mean(y)`
          for a one-sided test.
        - If 'x_mean < y_mean', uses
          `func=lambda x, y: np.mean(y) - np.mean(x)`
          for a one-sided test.
    method : 'approximate' or 'exact' (default: 'exact')
        If 'exact' (default), all possible permutations are considered.
        If 'approximate' the number of drawn samples is
        given by `num_rounds`.
        Note that 'exact' is typically not feasible unless the dataset
        size is relatively small.
    num_rounds : int (default: 1000)
        The number of permutation samples if `method='approximate'`.
    seed : int or None (default: None)
        The random seed for generating permutation samples if
        `method='approximate'`.

    Returns
    ----------
    p-value under the null hypothesis

    Raises
    ----------
    AttributeError
        If `method` is not 'approximate' or 'exact', or if `func` is a
        string other than the three supported ones.

    Examples
    -----------
    For usage examples, please see
    http://rasbt.github.io/mlxtend/user_guide/evaluate/permutation_test/

    """

    if method not in ('approximate', 'exact'):
        # Wrap in a 1-tuple so that a tuple-valued `method` is formatted
        # instead of being unpacked as several format arguments.
        raise AttributeError('method must be "approximate"'
                             ' or "exact", got %s' % (method,))

    if isinstance(func, str):

        if func not in (
                'x_mean != y_mean', 'x_mean > y_mean', 'x_mean < y_mean'):
            raise AttributeError('Provide a custom function'
                                 ' lambda x,y: ... or a string'
                                 ' in ("x_mean != y_mean", '
                                 '"x_mean > y_mean", "x_mean < y_mean")')

        elif func == 'x_mean != y_mean':
            def func(x, y):
                return np.abs(np.mean(x) - np.mean(y))

        elif func == 'x_mean > y_mean':
            def func(x, y):
                return np.mean(x) - np.mean(y)

        else:
            def func(x, y):
                return np.mean(y) - np.mean(x)

    rng = np.random.RandomState(seed)

    m, n = len(x), len(y)
    # np.hstack returns a new array, so shuffling below never touches the
    # caller's data.
    combined = np.hstack((x, y))

    at_least_as_extreme = 0.
    reference_stat = func(x, y)

    # Note that whether we compute the combinations or permutations
    # does not affect the results, since the number of permutations
    # n_A specific objects in A and n_B specific objects in B is the
    # same for all combinations in x_1, ... x_{n_A} and
    # x_{n_{A+1}}, ... x_{n_A + n_B}
    # In other words, for any given number of combinations, we get
    # n_A! x n_B! times as many permutations; however, the computed
    # value of those permutations that are merely re-arranged combinations
    # does not change. Hence, the result, since we divide by the number of
    # combinations or permutations is the same, the permutations simply have
    # "n_A! x n_B!" as a scaling factor in the numerator and denominator
    # and using combinations instead of permutations simply saves computational
    # time

    if method == 'exact':
        total = m + n
        # Reusable membership mask: True marks the positions assigned to
        # the first group. Boolean indexing returns elements in ascending
        # position order, which matches the sorted index tuples produced
        # by `combinations`.
        in_first = np.zeros(total, dtype=bool)
        num_rounds = 0

        for indices_x in combinations(range(total), m):
            in_first[:] = False
            in_first[list(indices_x)] = True
            diff = func(combined[in_first], combined[~in_first])

            # isclose guards against ties being lost to float round-off
            if diff > reference_stat or np.isclose(diff, reference_stat):
                at_least_as_extreme += 1.

            # Counting here yields (m + n)! / (m! * n!) without factorials.
            num_rounds += 1

        # The observed split is one of the enumerated combinations, so no
        # extra correction term is added in this branch.

    else:
        for _ in range(num_rounds):
            rng.shuffle(combined)
            diff = func(combined[:m], combined[m:])

            if diff > reference_stat or np.isclose(diff, reference_stat):
                at_least_as_extreme += 1.

        # To cover the actual experiment results
        at_least_as_extreme += 1
        num_rounds += 1

    return at_least_as_extreme / num_rounds
|