24 lines
602 B
Python
24 lines
602 B
Python
![]() |
import pytest
|
||
|
|
||
|
from numpy import array_api as xp
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
|
||
|
"obj, axis, expected",
|
||
|
[
|
||
|
([0, 0], -1, [0, 1]),
|
||
|
([0, 1, 0], -1, [1, 0, 2]),
|
||
|
([[0, 1], [1, 1]], 0, [[1, 0], [0, 1]]),
|
||
|
([[0, 1], [1, 1]], 1, [[1, 0], [0, 1]]),
|
||
|
],
|
||
|
)
|
||
|
def test_stable_desc_argsort(obj, axis, expected):
|
||
|
"""
|
||
|
Indices respect relative order of a descending stable-sort
|
||
|
|
||
|
See https://github.com/numpy/numpy/issues/20778
|
||
|
"""
|
||
|
x = xp.asarray(obj)
|
||
|
out = xp.argsort(x, axis=axis, stable=True, descending=True)
|
||
|
assert xp.all(out == xp.asarray(expected))
|