Skip to content
Prev Previous commit
Next Next commit
Update scalar conversion tests for non-0d arrays
  • Loading branch information
vlad-perevezentsev committed Dec 22, 2025
commit 33b2c8de1242113e4d8bd7a00040d044bb42136c
75 changes: 51 additions & 24 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import numpy as np
import pytest
from numpy.testing import assert_raises_regex

import dpctl
import dpctl.memory as dpm
Expand Down Expand Up @@ -282,34 +283,60 @@ def test_properties(dt):
V.mT


@pytest.mark.parametrize("func", [bool, float, int, complex])
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"])
def test_copy_scalar_with_func(func, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
assert func(X) == func(Y)
class TestCopyScalar:
def test_copy_bool_scalar_with_func(self, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
assert bool(X) == bool(Y)

@pytest.mark.parametrize("func", [float, int, complex])
def test_copy_numeric_scalar_with_func(self, func, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
# Non-0D numeric arrays must not be convertible to Python scalars
if len(shape) != 0:
assert_raises_regex(TypeError, "only 0-dimensional arrays", func, X)
else:
# 0D arrays are allowed to convert
assert func(X) == func(Y)

def test_copy_bool_scalar_with_method(self, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
Comment thread
vlad-perevezentsev marked this conversation as resolved.
Outdated
assert getattr(X, "__bool__")() == getattr(Y, "__bool__")()

@pytest.mark.parametrize(
"method", ["__bool__", "__float__", "__int__", "__complex__"]
)
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"])
def test_copy_scalar_with_method(method, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
assert getattr(X, method)() == getattr(Y, method)()
@pytest.mark.parametrize("method", ["__float__", "__int__", "__complex__"])
def test_copy_numeric_scalar_with_method(self, method, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
if len(shape) != 0:
assert_raises_regex(
TypeError, "only 0-dimensional arrays", getattr(X, method)
)
else:
assert getattr(X, method)() == getattr(Y, method)()


@pytest.mark.parametrize("func", [bool, float, int, complex])
Expand Down
Loading