diff --git a/dpctl/tensor/_indexing_functions.py b/dpctl/tensor/_indexing_functions.py index f166e94f49..a2301b5fb2 100644 --- a/dpctl/tensor/_indexing_functions.py +++ b/dpctl/tensor/_indexing_functions.py @@ -26,8 +26,18 @@ from ._copy_utils import _extract_impl, _nonzero_impl -def take(x, indices, /, *, axis=None, mode="clip"): - """take(x, indices, axis=None, mode="clip") +def _get_indexing_mode(name): + modes = {"wrap": 0, "clip": 1} + try: + return modes[name] + except KeyError: + raise ValueError( + "`mode` must be `wrap` or `clip`." "Got `{}`.".format(name) + ) + + +def take(x, indices, /, *, axis=None, mode="wrap"): + """take(x, indices, axis=None, mode="wrap") Takes elements from array along a given axis. @@ -42,15 +52,15 @@ def take(x, indices, /, *, axis=None, mode="clip"): Default: `None`. mode: How out-of-bounds indices will be handled. - "clip" - clamps indices to (-n <= i < n), then wraps + "wrap" - clamps indices to (-n <= i < n), then wraps negative indices. - "wrap" - wraps both negative and positive indices. - Default: `"clip"`. + "clip" - clips indices to (0 <= i < n) + Default: `"wrap"`. Returns: out: usm_ndarray Array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:] - filled with elements . + filled with elements from x. """ if not isinstance(x, dpt.usm_ndarray): raise TypeError( @@ -80,11 +90,7 @@ def take(x, indices, /, *, axis=None, mode="clip"): [x.usm_type, indices.usm_type] ) - modes = {"clip": 0, "wrap": 1} - try: - mode = modes[mode] - except KeyError: - raise ValueError("`mode` must be `clip` or `wrap`.") + mode = _get_indexing_mode(mode) x_ndim = x.ndim if axis is None: @@ -114,8 +120,8 @@ def take(x, indices, /, *, axis=None, mode="clip"): return res -def put(x, indices, vals, /, *, axis=None, mode="clip"): - """put(x, indices, vals, axis=None, mode="clip") +def put(x, indices, vals, /, *, axis=None, mode="wrap"): + """put(x, indices, vals, axis=None, mode="wrap") Puts values of an array into another array along a given axis. @@ -134,10 +140,10 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"): Default: `None`. mode: How out-of-bounds indices will be handled. - "clip" - clamps indices to (-axis_size <= i < axis_size), - then wraps negative indices. - "wrap" - wraps both negative and positive indices. - Default: `"clip"`. + "wrap" - clamps indices to (-n <= i < n), then wraps + negative indices. + "clip" - clips indices to (0 <= i < n) + Default: `"wrap"`. """ if not isinstance(x, dpt.usm_ndarray): raise TypeError( @@ -175,11 +181,8 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"): if exec_q is None: raise dpctl.utils.ExecutionPlacementError vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_) - modes = {"clip": 0, "wrap": 1} - try: - mode = modes[mode] - except KeyError: - raise ValueError("`mode` must be `clip` or `wrap`.") + + mode = _get_indexing_mode(mode) x_ndim = x.ndim if axis is None: diff --git a/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp b/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp index 07b0a4e9b8..40f457ce15 100644 --- a/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp +++ b/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp @@ -46,10 +46,10 @@ namespace py = pybind11; template class take_kernel; template class put_kernel; -class ClipIndex +class WrapIndex { public: - ClipIndex() = default; + WrapIndex() = default; void operator()(py::ssize_t max_item, py::ssize_t &ind) const { @@ -60,16 +60,15 @@ class ClipIndex } }; -class WrapIndex +class ClipIndex { public: - WrapIndex() = default; + ClipIndex() = default; void operator()(py::ssize_t max_item, py::ssize_t &ind) const { max_item = std::max(max_item, 1); - ind = (ind < 0) ? (ind + max_item * ((-ind / max_item) + 1)) % max_item - : ind % max_item; + ind = std::clamp(ind, 0, max_item - 1); return; } }; diff --git a/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp b/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp index cb148c7df3..c2589836fd 100644 --- a/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp +++ b/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp @@ -40,8 +40,8 @@ #include "integer_advanced_indexing.hpp" #define INDEXING_MODES 2 -#define CLIP_MODE 0 -#define WRAP_MODE 1 +#define WRAP_MODE 0 +#define CLIP_MODE 1 namespace dpctl { @@ -252,8 +252,8 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src, throw py::value_error("Axis cannot be negative."); } - if (mode != 0 && mode != 1) { - throw py::value_error("Mode must be 0 or 1."); + if (mode != 0 && mode != 1 && mode != 2) { + throw py::value_error("Mode must be 0, 1, or 2."); } const dpctl::tensor::usm_ndarray ind_rep = ind[0]; @@ -575,8 +575,8 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst, throw py::value_error("Axis cannot be negative."); } - if (mode != 0 && mode != 1) { - throw py::value_error("Mode must be 0 or 1."); + if (mode != 0 && mode != 1 && mode != 2) { + throw py::value_error("Mode must be 0, 1, or 2."); } if (!dst.is_writable()) { diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 5d160bf0f1..a57ea83cea 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -20,6 +20,7 @@ from helper import get_queue_or_skip, skip_if_dtype_not_supported from numpy.testing import assert_array_equal +import dpctl import dpctl.tensor as dpt from dpctl.utils import ExecutionPlacementError @@ -895,20 +896,21 @@ def test_integer_indexing_modes(): q = get_queue_or_skip() x = dpt.arange(5, sycl_queue=q) + x_np = dpt.asnumpy(x) + + # wrapping negative indices + ind = dpt.asarray([-4, -3, 0, 2, 4], dtype=np.intp, sycl_queue=q) - # wrapping - ind = dpt.asarray([-6, -3, 0, 2, 6], dtype=np.intp, sycl_queue=q) res = dpt.take(x, ind, mode="wrap") - expected_arr = np.take(dpt.asnumpy(x), dpt.asnumpy(ind), mode="wrap") + expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="raise") assert (dpt.asnumpy(res) == expected_arr).all() - # clipping to -n<=i