# Authors: Gael Varoquaux
#          Justin Vincent
#          Lars Buitinck
# License: BSD 3 clause

import numpy as np
import pytest

from sklearn.utils._testing import assert_array_equal
from sklearn.utils.fixes import _object_dtype_isnan, _smallest_admissible_index_dtype

# Upper bounds of the two signed integer dtypes exercised by the tests below.
_INT32_MAX = np.iinfo(np.int32).max
_INT64_MAX = np.iinfo(np.int64).max


@pytest.mark.parametrize("dtype, val", [(object, 1), (object, "a"), (float, 1)])
def test_object_dtype_isnan(dtype, val):
    """Check that `_object_dtype_isnan` flags exactly the NaN entries."""
    data = np.array([[val, np.nan], [np.nan, val]], dtype=dtype)
    nan_positions = np.array([[False, True], [True, False]])

    assert_array_equal(_object_dtype_isnan(data), nan_positions)


@pytest.mark.parametrize(
    "params, expected_dtype",
    [
        # default behaviour
        pytest.param({}, np.int32),
        pytest.param({"maxval": _INT32_MAX}, np.int32),
        pytest.param({"maxval": _INT32_MAX + 1}, np.int64),
    ],
)
def test_smallest_admissible_index_dtype_max_val(params, expected_dtype):
    """Check the dtype returned by `_smallest_admissible_index_dtype` when only
    the `maxval` parameter (if anything) is provided.
    """
    index_dtype = _smallest_admissible_index_dtype(**params)
    assert index_dtype == expected_dtype


@pytest.mark.parametrize(
    "params, expected_dtype",
    [
        # Arrays dtype is int64 and thus should not be downcasted to int32
        # without checking the content or providing maxval.
        pytest.param({"arrays": np.array([1, 2], dtype=np.int64)}, np.int64),
        # One of the arrays is int64 and should not be downcasted to int32
        # for the same reasons.
        pytest.param(
            {
                "arrays": (
                    np.array([1, 2], dtype=np.int32),
                    np.array([1, 2], dtype=np.int64),
                )
            },
            np.int64,
        ),
        # Both arrays are already int32: we can just keep this dtype.
        pytest.param(
            {
                "arrays": (
                    np.array([1, 2], dtype=np.int32),
                    np.array([1, 2], dtype=np.int32),
                )
            },
            np.int32,
        ),
        # Arrays should be upcasted to at least int32 precision.
        pytest.param({"arrays": np.array([1, 2], dtype=np.int8)}, np.int32),
        # Check that `maxval` takes precedence over the arrays and thus upcast
        # to int64.
        pytest.param(
            {
                "arrays": np.array([1, 2], dtype=np.int32),
                "maxval": _INT32_MAX + 1,
            },
            np.int64,
        ),
    ],
)
def test_smallest_admissible_index_dtype_without_checking_contents(
    params, expected_dtype
):
    """Check the dtype returned by `_smallest_admissible_index_dtype` when
    arrays are passed but their contents are not inspected.
    """
    index_dtype = _smallest_admissible_index_dtype(**params)
    assert index_dtype == expected_dtype


@pytest.mark.parametrize(
    "params, expected_dtype",
    [
        # empty arrays should always be converted to int32 indices
        pytest.param(
            {
                "arrays": (np.array([], dtype=np.int64), np.array([], dtype=np.int64)),
                "check_contents": True,
            },
            np.int32,
        ),
        # arrays respecting np.iinfo(np.int32).min < x < np.iinfo(np.int32).max
        # should be converted to int32,
        pytest.param(
            {"arrays": np.array([1], dtype=np.int64), "check_contents": True},
            np.int32,
        ),
        # otherwise, it should be converted to int64. We need to create a
        # uint32 array to accommodate a value > np.iinfo(np.int32).max
        pytest.param(
            {
                "arrays": np.array([_INT32_MAX + 1], dtype=np.uint32),
                "check_contents": True,
            },
            np.int64,
        ),
        # maxval should take precedence over the arrays contents and thus
        # upcast to int64.
        pytest.param(
            {
                "arrays": np.array([1], dtype=np.int32),
                "check_contents": True,
                "maxval": _INT32_MAX + 1,
            },
            np.int64,
        ),
        # when maxval is small, but check_contents is True and the contents
        # require np.int64, we still require np.int64 indexing in the end.
        pytest.param(
            {
                "arrays": np.array([_INT32_MAX + 1], dtype=np.uint32),
                "check_contents": True,
                "maxval": 1,
            },
            np.int64,
        ),
    ],
)
def test_smallest_admissible_index_dtype_by_checking_contents(params, expected_dtype):
    """Check the dtype returned by `_smallest_admissible_index_dtype` when both
    the dtype and the contents of the arrays are taken into account.
    """
    index_dtype = _smallest_admissible_index_dtype(**params)
    assert index_dtype == expected_dtype


@pytest.mark.parametrize(
    "params, err_type, err_msg",
    [
        pytest.param(
            {"maxval": _INT64_MAX + 1},
            ValueError,
            "is to large to be represented as np.int64",
        ),
        pytest.param(
            {"arrays": np.array([1, 2], dtype=np.float64)},
            ValueError,
            "Array dtype float64 is not supported",
        ),
        pytest.param(
            {"arrays": [1, 2]},
            TypeError,
            "Arrays should be of type np.ndarray",
        ),
    ],
)
def test_smallest_admissible_index_dtype_error(params, err_type, err_msg):
    """Check that invalid inputs raise the expected error type and message."""
    with pytest.raises(err_type, match=err_msg):
        _smallest_admissible_index_dtype(**params)