20 lines
537 B
Python
20 lines
537 B
Python
|
import pytest
|
||
|
|
||
|
from sklearn.utils._mask import safe_mask
|
||
|
from sklearn.utils.fixes import CSR_CONTAINERS
|
||
|
from sklearn.utils.validation import check_random_state
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_safe_mask(csr_container):
    """Check that masks returned by ``safe_mask`` index dense and CSR inputs.

    A boolean list with three ``True`` entries is passed through
    ``safe_mask``; indexing with the result must keep exactly three rows,
    both for the dense array and for its ``csr_container`` counterpart.
    """
    rng = check_random_state(0)
    dense = rng.rand(5, 4)
    sparse = csr_container(dense)
    row_selector = [False, False, True, True, True]

    dense_mask = safe_mask(dense, row_selector)
    assert dense[dense_mask].shape[0] == 3

    # The mask obtained for the dense input is fed back in for the sparse
    # input, so this also exercises safe_mask on an already-processed mask.
    sparse_mask = safe_mask(sparse, dense_mask)
    assert sparse[sparse_mask].shape[0] == 3
|