Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add information displayed on failure, renamed variables
Add check of computed against expected indices
  • Loading branch information
oleksandr-pavlyk committed Dec 27, 2024
commit 1bb83bc33ad73b55605536d4d4f7f1e185575bf0
157 changes: 131 additions & 26 deletions dpctl/tests/test_usm_ndarray_top_k.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,38 @@
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported


def _expected_largest_inds(inp, n, shift, k):
"Computed expected top_k indices for mode='largest'"
assert k < n
ones_start_id = shift % (2 * n)

alloc_dev = inp.device

if ones_start_id < n:
expected_inds = dpt.arange(
ones_start_id, ones_start_id + k, dtype="i8", device=alloc_dev
)
else:
# wrap-around
ones_end_id = (ones_start_id + n) % (2 * n)
if ones_end_id >= k:
expected_inds = dpt.arange(k, dtype="i8", device=alloc_dev)
else:
expected_inds = dpt.concat(
(
dpt.arange(ones_end_id, dtype="i8", device=alloc_dev),
dpt.arange(
ones_start_id,
ones_start_id + k - ones_end_id,
dtype="i8",
device=alloc_dev,
),
)
)

return expected_inds


@pytest.mark.parametrize(
"dtype",
[
Expand All @@ -38,23 +70,57 @@
"c16",
],
)
@pytest.mark.parametrize("n", [33, 255, 511, 1021, 8193])
def test_topk_1d_largest(dtype, n):
@pytest.mark.parametrize("n", [33, 43, 255, 511, 1021, 8193])
def test_top_k_1d_largest(dtype, n):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

shift, k = 734, 5
o = dpt.ones(n, dtype=dtype)
z = dpt.zeros(n, dtype=dtype)
zo = dpt.concat((o, z))
inp = dpt.roll(zo, 734)
k = 5
oz = dpt.concat((o, z))
inp = dpt.roll(oz, shift)

expected_inds = _expected_largest_inds(oz, n, shift, k)

s = dpt.top_k(inp, k, mode="largest")
assert s.values.shape == (k,)
assert s.values.dtype == inp.dtype
assert s.indices.shape == (k,)
assert dpt.all(s.values == dpt.ones(k, dtype=dtype))
assert dpt.all(s.values == inp[s.indices])
assert dpt.all(s.indices == expected_inds)
assert dpt.all(s.values == dpt.ones(k, dtype=dtype)), s.values
assert dpt.all(s.values == inp[s.indices]), s.indices


def _expected_smallest_inds(inp, n, shift, k):
"Computed expected top_k indices for mode='smallest'"
assert k < n
zeros_start_id = (n + shift) % (2 * n)
zeros_end_id = (shift) % (2 * n)

alloc_dev = inp.device

if zeros_start_id < zeros_end_id:
expected_inds = dpt.arange(
zeros_start_id, zeros_start_id + k, dtype="i8", device=alloc_dev
)
else:
if zeros_end_id >= k:
expected_inds = dpt.arange(k, dtype="i8", device=alloc_dev)
else:
expected_inds = dpt.concat(
(
dpt.arange(zeros_end_id, dtype="i8", device=alloc_dev),
dpt.arange(
zeros_start_id,
zeros_start_id + k - zeros_end_id,
dtype="i8",
device=alloc_dev,
),
)
)

return expected_inds


@pytest.mark.parametrize(
Expand All @@ -75,41 +141,80 @@ def test_topk_1d_largest(dtype, n):
"c16",
],
)
@pytest.mark.parametrize("n", [33, 255, 257, 513, 1021, 8193])
def test_topk_1d_smallest(dtype, n):
@pytest.mark.parametrize("n", [37, 39, 61, 255, 257, 513, 1021, 8193])
def test_top_k_1d_smallest(dtype, n):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

shift, k = 734, 5
o = dpt.ones(n, dtype=dtype)
z = dpt.zeros(n, dtype=dtype)
zo = dpt.concat((o, z))
inp = dpt.roll(zo, 734)
k = 5
oz = dpt.concat((o, z))
inp = dpt.roll(oz, shift)

expected_inds = _expected_smallest_inds(oz, n, shift, k)

s = dpt.top_k(inp, k, mode="smallest")
assert s.values.shape == (k,)
assert s.values.dtype == inp.dtype
assert s.indices.shape == (k,)
assert dpt.all(s.values == dpt.zeros(k, dtype=dtype))
assert dpt.all(s.values == inp[s.indices])
assert dpt.all(s.indices == expected_inds)
assert dpt.all(s.values == dpt.zeros(k, dtype=dtype)), s.values
assert dpt.all(s.values == inp[s.indices]), s.indices


# triage failing top k radix implementation on CPU
# replicates from Python behavior of radix sort topk implementation
@pytest.mark.parametrize("n", [33, 255, 511, 1021, 8193])
def test_topk_largest_1d_radix_i1_255(n):
@pytest.mark.parametrize(
"n",
[
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
61,
137,
255,
511,
1021,
8193,
],
)
def test_top_k_largest_1d_radix_i1(n):
get_queue_or_skip()
dt = "i1"

shift, k = 734, 5
o = dpt.ones(n, dtype=dt)
z = dpt.zeros(n, dtype=dt)
zo = dpt.concat((o, z))
inp = dpt.roll(zo, 734)
k = 5

sorted = dpt.copy(dpt.sort(inp, descending=True, kind="radixsort")[:k])
argsorted = dpt.copy(
dpt.argsort(inp, descending=True, kind="radixsort")[:k]
)
assert dpt.all(sorted == dpt.ones(k, dtype=dt))
assert dpt.all(sorted == inp[argsorted])
oz = dpt.concat((o, z))
inp = dpt.roll(oz, shift)

expected_inds = _expected_largest_inds(oz, n, shift, k)

sorted_v = dpt.sort(inp, descending=True, kind="radixsort")
argsorted = dpt.argsort(inp, descending=True, kind="radixsort")

assert dpt.all(sorted_v == inp[argsorted])

topk_vals = dpt.copy(sorted_v[:k])
topk_inds = dpt.copy(argsorted[:k])

assert dpt.all(topk_vals == dpt.ones(k, dtype=dt))
assert dpt.all(topk_inds == expected_inds)

assert dpt.all(topk_vals == inp[topk_inds]), topk_inds