Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ cdef class usm_ndarray:
mem_ptr = <char *>(<size_t> ary_iface['data'][0])
ary_ptr = <char *>(<size_t> self.data_)
ro_flag = False if (self.flags_ & USM_ARRAY_WRITEABLE) else True
ary_iface['data'] = (<size_t> ary_ptr, ro_flag)
ary_iface['data'] = (<size_t> mem_ptr, ro_flag)
ary_iface['shape'] = self.shape
if (self.strides_):
ary_iface['strides'] = _make_int_tuple(self.nd_, self.strides_)
Expand All @@ -335,7 +335,7 @@ cdef class usm_ndarray:
"""
Gives the number of indices needed to address elements of this array.
"""
return int(self.nd_)
return self.nd_

@property
def usm_data(self):
Expand Down
71 changes: 69 additions & 2 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
import pytest

import dpctl

# import dpctl.memory as dpmem
import dpctl.memory as dpm
import dpctl.tensor as dpt
from dpctl.tensor._usmarray import Device

Expand Down Expand Up @@ -224,3 +223,71 @@ def test_slice_constructor_3d():
assert np.array_equal(
_to_numpy(Xusm[ind]), Xh[ind]
), "Failed for {}".format(ind)


@pytest.mark.parametrize("usm_type", ["device", "shared", "host"])
def test_slice_suai(usm_type):
Xh = np.arange(0, 10, dtype="u1")
default_device = dpctl.select_default_device()
Xusm = _from_numpy(Xh, device=default_device, usm_type=usm_type)
for ind in [slice(2, 3, None), slice(5, 7, None), slice(3, 9, None)]:
assert np.array_equal(
dpm.as_usm_memory(Xusm[ind]).copy_to_host(), Xh[ind]
), "Failed for {}".format(ind)


def test_slicing_basic():
Xusm = dpt.usm_ndarray((10, 5), dtype="c16")
Xusm[None]
Xusm[...]
Xusm[8]
Xusm[-3]
with pytest.raises(IndexError):
Xusm[..., ...]
with pytest.raises(IndexError):
Xusm[1, 1, :, 1]
Xusm[:, -4]
with pytest.raises(IndexError):
Xusm[:, -128]
with pytest.raises(TypeError):
Xusm[{1, 2, 3, 4, 5, 6, 7}]


def test_ctor_invalid_shape():
with pytest.raises(TypeError):
dpt.usm_ndarray(dict())


def test_ctor_invalid_order():
with pytest.raises(ValueError):
dpt.usm_ndarray((5, 5, 3), order="Z")


def test_ctor_buffer_kwarg():
dpt.usm_ndarray(10, buffer=b"device")
with pytest.raises(ValueError):
dpt.usm_ndarray(10, buffer="invalid_param")
Xusm = dpt.usm_ndarray((10, 5), dtype="c16")
X2 = dpt.usm_ndarray(Xusm.shape, buffer=Xusm, dtype=Xusm.dtype)
assert np.array_equal(
Xusm.usm_data.copy_to_host(), X2.usm_data.copy_to_host()
)
with pytest.raises(ValueError):
dpt.usm_ndarray(10, buffer=dict())


def test_usm_ndarray_props():
Xusm = dpt.usm_ndarray((10, 5), dtype="c16", order="F")
Xusm.ndim
repr(Xusm)
Xusm.flags
Xusm.__sycl_usm_array_interface__
Xusm.device
Xusm.strides
Xusm.real
Xusm.imag
try:
dpctl.SyclQueue("cpu")
except dpctl.SyclQueueCreationError:
pytest.skip("Sycl device CPU was not detected")
Xusm.to_device("cpu")