From eccd0263f25afefd768a0471b4b0a2f71bb478f5 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 23 Mar 2023 09:08:22 -0500 Subject: [PATCH 1/2] Closes gh-1135 Handle x[True] and x[False] as NumPy does, even though the behavior may be undocumented. NumPy treats True as None (insert axis with size 1), and treats False as None followed by empty slicing (insert axis with size 0). Changed the logic of _basic_slice_meta utility function to correctly handle boolean scalars (surprisingly, `insinstance(True, int)` evaluates to `True`). 0d arrays are handled by Python scalars. Introduced _is_integral and _is_boolean utilty functions and used them in `_basic_slice_meta` utility. --- dpctl/tensor/_slicing.pxi | 64 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 59 insertions(+), 5 deletions(-) diff --git a/dpctl/tensor/_slicing.pxi b/dpctl/tensor/_slicing.pxi index 10b5c58395..361dd906c3 100644 --- a/dpctl/tensor/_slicing.pxi +++ b/dpctl/tensor/_slicing.pxi @@ -1,6 +1,6 @@ # Data Parallel Control (dpctl) # -# Copyright 2020-2022 Intel Corporation +# Copyright 2020-2023 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +15,11 @@ # limitations under the License. import numbers +from cpython.buffer cimport PyObject_CheckBuffer + + +cdef bint _is_buffer(object o): + return PyObject_CheckBuffer(o) cdef Py_ssize_t _slice_len( @@ -36,14 +41,23 @@ cdef Py_ssize_t _slice_len( cdef bint _is_integral(object x) except *: """Gives True if x is an integral slice spec""" - if isinstance(x, (int, numbers.Integral)): - return True if isinstance(x, usm_ndarray): if x.ndim > 0: return False if x.dtype.kind not in "ui": return False return True + if isinstance(x, bool): + return False + if isinstance(x, int): + return True + if _is_buffer(x): + mbuf = memoryview(x) + if mbuf.ndim == 0: + f = mbuf.format + return f in "bBhHiIlLqQ" + else: + return False if callable(getattr(x, "__index__", None)): try: x.__index__() @@ -53,6 +67,34 @@ cdef bint _is_integral(object x) except *: return False +cdef bint _is_boolean(object x) except *: + """Gives True if x is an integral slice spec""" + if isinstance(x, usm_ndarray): + if x.ndim > 0: + return False + if x.dtype.kind not in "b": + return False + return True + if isinstance(x, bool): + return True + if isinstance(x, int): + return False + if _is_buffer(x): + mbuf = memoryview(x) + if mbuf.ndim == 0: + f = mbuf.format + return f in "?" + else: + return False + if callable(getattr(x, "__bool__", None)): + try: + x.__bool__() + except (TypeError, ValueError): + return False + return True + return False + + def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): """ Give basic slicing index `ind` and array layout information produce @@ -82,6 +124,11 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): _no_advanced_ind, _no_advanced_pos ) + elif _is_boolean(ind): + if ind: + return ((1,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos) + else: + return ((0,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos) elif _is_integral(ind): ind = ind.__index__() if 0 <= ind < shape[0]: @@ -117,6 +164,10 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): axes_referenced += 1 if array_streak_started: array_streak_interrupted = True + elif _is_boolean(i): + newaxis_count += 1 + if array_streak_started: + array_streak_interrupted = True elif _is_integral(i): explicit_index += 1 axes_referenced += 1 @@ -133,9 +184,9 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): "separated by basic slicing specs." ) dt_k = i.dtype.kind - if dt_k == "b": + if dt_k == "b" and i.ndim > 0: axes_referenced += i.ndim - elif dt_k in "ui": + elif dt_k in "ui" and i.ndim > 0: axes_referenced += 1 else: raise IndexError( @@ -186,6 +237,9 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): if sh_i == 0: is_empty = True k = k_new + elif _is_boolean(ind_i): + new_shape.append(1 if ind_i else 0) + new_strides.append(0) elif _is_integral(ind_i): ind_i = ind_i.__index__() if 0 <= ind_i < shape[k]: From 8c48886370017bd2d6c042bcf4d73dc7de5e020a Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 23 Mar 2023 09:21:05 -0500 Subject: [PATCH 2/2] Added tests for True/False indexing of usm_ndarray --- dpctl/tests/test_usm_ndarray_indexing.py | 26 ++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index a57ea83cea..41688075e0 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -455,6 +455,32 @@ def test_integer_strided_indexing(): assert (dpt.asnumpy(y) == dpt.asnumpy(yc)).all() +def test_TrueFalse_indexing(): + get_queue_or_skip() + n0, n1 = 2, 3 + x = dpt.ones((n0, n1)) + for ind in [True, dpt.asarray(True)]: + y1 = x[ind] + assert y1.shape == (1, n0, n1) + assert y1._pointer == x._pointer + y2 = x[:, ind] + assert y2.shape == (n0, 1, n1) + assert y2._pointer == x._pointer + y3 = x[..., ind] + assert y3.shape == (n0, n1, 1) + assert y3._pointer == x._pointer + for ind in [False, dpt.asarray(False)]: + y1 = x[ind] + assert y1.shape == (0, n0, n1) + assert y1._pointer == x._pointer + y2 = x[:, ind] + assert y2.shape == (n0, 0, n1) + assert y2._pointer == x._pointer + y3 = x[..., ind] + assert y3.shape == (n0, n1, 0) + assert y3._pointer == x._pointer + + @pytest.mark.parametrize( "data_dt", _all_dtypes,