210 lines
7.4 KiB
Python
210 lines
7.4 KiB
Python
from itertools import permutations
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from pandas._libs.interval import IntervalTree
|
|
from pandas.compat import IS64
|
|
|
|
import pandas._testing as tm
|
|
|
|
|
|
def skipif_32bit(param):
    """
    Wrap ``param`` in a ``pytest.param`` that is skipped on 32-bit builds.

    Used for the IntervalTree ``leaf_size`` parameters affected by GH 23440.
    """
    only_on_64bit = pytest.mark.skipif(
        not IS64, reason="GH 23440: int type mismatch on 32bit"
    )
    return pytest.param(param, marks=only_on_64bit)
|
|
|
|
|
|
@pytest.fixture(params=("int64", "float64", "uint64"))
def dtype(request):
    """Provide each endpoint dtype name ("int64", "float64", "uint64") in turn."""
    dtype_name = request.param
    return dtype_name
|
|
|
|
|
|
@pytest.fixture(params=[skipif_32bit(1), skipif_32bit(2), 10])
def leaf_size(request):
    """
    IntervalTree ``leaf_size`` values consumed by the ``tree`` fixture.

    The small sizes 1 and 2 are wrapped with ``skipif_32bit`` (GH 23440),
    so only 10 is exercised on 32-bit platforms.
    """
    size = request.param
    return size
|
|
|
|
|
|
@pytest.fixture(
    params=[
        np.arange(5, dtype="int64"),
        np.arange(5, dtype="uint64"),
        np.arange(5, dtype="float64"),
        np.array([0, 1, 2, 3, 4, np.nan], dtype="float64"),
    ]
)
def tree(request, leaf_size):
    """
    Build an IntervalTree whose intervals each span 2 units.

    The left endpoints come from the fixture parameter (int64, uint64,
    float64, or float64 with a trailing NaN). Each right endpoint is its
    left endpoint + 2. The tree uses the ``leaf_size`` fixture's value.
    """
    starts = request.param
    ends = starts + 2
    return IntervalTree(starts, ends, leaf_size=leaf_size)
|
|
|
|
|
|
class TestIntervalTree:
    """
    Tests for ``IntervalTree`` indexing and overlap detection.

    NOTE(review): the ``closed`` fixture used below is not defined in this
    module; presumably it comes from a shared conftest — confirm.
    """

    def test_get_indexer(self, tree):
        # positions 0 and 4 contain the first two targets; 6.5 matches nothing
        tm.assert_numpy_array_equal(
            tree.get_indexer(np.array([1.0, 5.5, 6.5])),
            np.array([0, 4, -1], dtype="intp"),
        )

        # a target hitting several intervals cannot yield a unique indexer
        msg = "'indexer does not intersect a unique set of intervals'"
        with pytest.raises(KeyError, match=msg):
            tree.get_indexer(np.array([3.0]))

    @pytest.mark.parametrize(
        "dtype, target_value, target_dtype",
        [("int64", 2**63 + 1, "uint64"), ("uint64", -1, "int64")],
    )
    def test_get_indexer_overflow(self, dtype, target_value, target_dtype):
        # the target value lies outside the range of the tree's integer dtype
        tree = IntervalTree(
            np.array([0, 1], dtype=dtype), np.array([1, 2], dtype=dtype)
        )
        target = np.array([target_value], dtype=target_dtype)

        tm.assert_numpy_array_equal(
            tree.get_indexer(target), np.array([-1], dtype="intp")
        )

    def test_get_indexer_non_unique(self, tree):
        indexer, missing = tree.get_indexer_non_unique(np.array([1.0, 2.0, 6.5]))

        # indexer concatenates the matches of each target in turn; multi-match
        # slices are sorted so the check ignores ordering within a target
        checks = [
            (indexer[:1], [0]),
            (np.sort(indexer[1:3]), [0, 1]),
            (np.sort(indexer[3:]), [-1]),
        ]
        for actual, wanted in checks:
            tm.assert_numpy_array_equal(actual, np.array(wanted, dtype="intp"))

        # only the third target (6.5) found no interval
        tm.assert_numpy_array_equal(missing, np.array([2], dtype="intp"))

    @pytest.mark.parametrize(
        "dtype, target_value, target_dtype",
        [("int64", 2**63 + 1, "uint64"), ("uint64", -1, "int64")],
    )
    def test_get_indexer_non_unique_overflow(self, dtype, target_value, target_dtype):
        tree = IntervalTree(
            np.array([0, 2], dtype=dtype), np.array([1, 3], dtype=dtype)
        )
        indexer, missing = tree.get_indexer_non_unique(
            np.array([target_value], dtype=target_dtype)
        )

        # the lone out-of-range target matches nothing and is reported missing
        tm.assert_numpy_array_equal(indexer, np.array([-1], dtype="intp"))
        tm.assert_numpy_array_equal(missing, np.array([0], dtype="intp"))

    def test_duplicates(self, dtype):
        starts = np.array([0, 0, 0], dtype=dtype)
        tree = IntervalTree(starts, starts + 1)

        # three identical intervals all contain 0.5, so no unique indexer exists
        msg = "'indexer does not intersect a unique set of intervals'"
        with pytest.raises(KeyError, match=msg):
            tree.get_indexer(np.array([0.5]))

        # ...but the non-unique lookup returns every one of them
        indexer, missing = tree.get_indexer_non_unique(np.array([0.5]))
        tm.assert_numpy_array_equal(
            np.sort(indexer), np.array([0, 1, 2], dtype="intp")
        )
        tm.assert_numpy_array_equal(missing, np.array([], dtype="intp"))

    @pytest.mark.parametrize(
        "leaf_size", [skipif_32bit(1), skipif_32bit(10), skipif_32bit(100), 10000]
    )
    def test_get_indexer_closed(self, closed, leaf_size):
        count = 1000
        starts = np.arange(count, dtype="float64")
        hits = starts.astype("intp")
        misses = np.full(count, -1, dtype="intp")

        tree = IntervalTree(starts, starts + 0.5, closed=closed, leaf_size=leaf_size)

        # interior points always fall in their own interval
        tm.assert_numpy_array_equal(hits, tree.get_indexer(starts + 0.25))

        # left endpoints match only when the left side is closed
        on_left = hits if tree.closed_left else misses
        tm.assert_numpy_array_equal(on_left, tree.get_indexer(starts + 0.0))

        # right endpoints match only when the right side is closed
        on_right = hits if tree.closed_right else misses
        tm.assert_numpy_array_equal(on_right, tree.get_indexer(starts + 0.5))

    @pytest.mark.parametrize(
        "left, right, expected",
        [
            (np.array([0, 1, 4], dtype="int64"), np.array([2, 3, 5]), True),
            (np.array([0, 1, 2], dtype="int64"), np.array([5, 4, 3]), True),
            (np.array([0, 1, np.nan]), np.array([5, 4, np.nan]), True),
            (np.array([0, 2, 4], dtype="int64"), np.array([1, 3, 5]), False),
            (np.array([0, 2, np.nan]), np.array([1, 3, np.nan]), False),
        ],
    )
    @pytest.mark.parametrize("order", [list(p) for p in permutations(range(3))])
    def test_is_overlapping(self, closed, order, left, right, expected):
        # GH 23309
        # every permutation of the same intervals must give the same answer
        tree = IntervalTree(left[order], right[order], closed=closed)
        assert tree.is_overlapping is expected

    @pytest.mark.parametrize("order", [list(p) for p in permutations(range(3))])
    def test_is_overlapping_endpoints(self, closed, order):
        """Intervals sharing only an endpoint overlap iff closed on both sides."""
        # GH 23309
        # intervals 0-1, 1-2, 2-3: neighbours touch at a single point
        left, right = np.arange(3, dtype="int64"), np.arange(1, 4)
        tree = IntervalTree(left[order], right[order], closed=closed)
        assert tree.is_overlapping is (closed == "both")

    @pytest.mark.parametrize(
        "left, right",
        [
            (np.array([], dtype="int64"), np.array([], dtype="int64")),
            (np.array([0], dtype="int64"), np.array([1], dtype="int64")),
            (np.array([np.nan]), np.array([np.nan])),
            (np.array([np.nan] * 3), np.array([np.nan] * 3)),
        ],
    )
    def test_is_overlapping_trivial(self, closed, left, right):
        # GH 23309
        # empty, single-interval and all-NaN trees never overlap
        assert IntervalTree(left, right, closed=closed).is_overlapping is False

    @pytest.mark.skipif(not IS64, reason="GH 23440")
    def test_construction_overflow(self):
        # GH 25485
        int64_max = np.iinfo(np.int64).max
        left = np.arange(101, dtype="int64")
        right = [int64_max] * 101
        tree = IntervalTree(left, right)

        # pivot should be the average of the left median (50) and the
        # right median (int64 max)
        assert tree.root.pivot == (50 + int64_max) / 2

    @pytest.mark.xfail(not IS64, reason="GH 23440")
    @pytest.mark.parametrize(
        "left, right, expected",
        [
            ([-np.inf, 1.0], [1.0, 2.0], 0.0),
            ([-np.inf, -2.0], [-2.0, -1.0], -2.0),
            ([-2.0, -1.0], [-1.0, np.inf], 0.0),
            ([1.0, 2.0], [2.0, np.inf], 2.0),
        ],
    )
    def test_inf_bound_infinite_recursion(self, left, right, expected):
        # GH 46658
        # list repetition gives 202 endpoints, some of them infinite; the
        # root pivot must still come out as the finite expected value
        tree = IntervalTree(left * 101, right * 101)
        assert tree.root.pivot == expected
|