diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 2a2afd60a4..77f102fa56 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -51,6 +51,7 @@ int16, int32, int64, + isdtype, uint8, uint16, uint32, @@ -125,6 +126,7 @@ "tril", "triu", "dtype", + "isdtype", "bool", "int8", "uint8", diff --git a/dpctl/tensor/_data_types.py b/dpctl/tensor/_data_types.py index c97afe37be..70129363de 100644 --- a/dpctl/tensor/_data_types.py +++ b/dpctl/tensor/_data_types.py @@ -31,8 +31,52 @@ complex64 = dtype("complex64") complex128 = dtype("complex128") + +def isdtype(dtype_, kind): + """isdtype(dtype, kind) + + Returns a boolean indicating whether a provided `dtype` is + of a specified data type `kind`. + + See [array API](array_api) for more information. + + [array_api]: https://data-apis.org/array-api/latest/ + """ + + if not isinstance(dtype_, dtype): + raise TypeError("Expected instance of `dpt.dtype`, got {dtype_}") + + if isinstance(kind, dtype): + return dtype_ == kind + + elif isinstance(kind, str): + if kind == "bool": + return dtype_ == dtype("bool") + elif kind == "signed integer": + return dtype_.kind == "i" + elif kind == "unsigned integer": + return dtype_.kind == "u" + elif kind == "integral": + return dtype_.kind in "iu" + elif kind == "real floating": + return dtype_.kind == "f" + elif kind == "complex floating": + return dtype_.kind == "c" + elif kind == "numeric": + return dtype_.kind in "iufc" + else: + raise ValueError(f"Unrecognized data type kind: {kind}") + + elif isinstance(kind, tuple): + return any(isdtype(dtype_, k) for k in kind) + + else: + raise TypeError(f"Unsupported data type kind: {kind}") + + __all__ = [ "dtype", + "isdtype", "bool", "int8", "uint8", diff --git a/dpctl/tests/test_tensor_dtype_routines.py b/dpctl/tests/test_tensor_dtype_routines.py new file mode 100644 index 0000000000..acb1bb6d8b --- /dev/null +++ b/dpctl/tests/test_tensor_dtype_routines.py @@ -0,0 +1,129 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +import dpctl.tensor as dpt + +list_dtypes = [ + "bool", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float16", + "float32", + "float64", + "complex64", + "complex128", +] + + +dtype_categories = { + "bool": ["bool"], + "signed integer": ["int8", "int16", "int32", "int64"], + "unsigned integer": ["uint8", "uint16", "uint32", "uint64"], + "integral": [ + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + ], + "real floating": ["float16", "float32", "float64"], + "complex floating": ["complex64", "complex128"], + "numeric": [d for d in list_dtypes if d != "bool"], +} + + +@pytest.mark.parametrize("kind_str", dtype_categories.keys()) +@pytest.mark.parametrize("dtype_str", list_dtypes) +def test_isdtype_kind_str(dtype_str, kind_str): + dt = dpt.dtype(dtype_str) + is_in_kind = dpt.isdtype(dt, kind_str) + expected = dtype_str in dtype_categories[kind_str] + assert is_in_kind == expected + + +@pytest.mark.parametrize("dtype_str", list_dtypes) +def test_isdtype_kind_tuple(dtype_str): + dt = dpt.dtype(dtype_str) + if dtype_str.startswith("bool"): + assert dpt.isdtype(dt, ("real floating", "bool")) + assert not dpt.isdtype( + dt, ("integral", "real floating", "complex floating") + ) + elif dtype_str.startswith("int"): + assert dpt.isdtype(dt, ("real floating", "signed integer")) + assert not dpt.isdtype( + dt, ("bool", "unsigned integer", "real floating") + ) + elif dtype_str.startswith("uint"): + assert dpt.isdtype(dt, ("bool", "unsigned integer")) + assert not dpt.isdtype(dt, ("real floating", "complex floating")) + elif dtype_str.startswith("float"): + assert dpt.isdtype(dt, ("complex floating", "real floating")) + assert not dpt.isdtype(dt, ("integral", "complex floating", "bool")) + else: + assert dpt.isdtype(dt, ("integral", "complex floating")) + assert not dpt.isdtype(dt, ("bool", "integral", "real floating")) + + +@pytest.mark.parametrize("dtype_str", list_dtypes) +def test_isdtype_kind_tuple_dtypes(dtype_str): + dt = dpt.dtype(dtype_str) + if dtype_str.startswith("bool"): + assert dpt.isdtype(dt, (dpt.int32, dpt.bool)) + assert not dpt.isdtype(dt, (dpt.int16, dpt.uint32, dpt.float64)) + + elif dtype_str.startswith("int"): + assert dpt.isdtype(dt, (dpt.int8, dpt.int16, dpt.int32, dpt.int64)) + assert not dpt.isdtype(dt, (dpt.bool, dpt.float32, dpt.complex64)) + + elif dtype_str.startswith("uint"): + assert dpt.isdtype(dt, (dpt.uint8, dpt.uint16, dpt.uint32, dpt.uint64)) + assert not dpt.isdtype(dt, (dpt.bool, dpt.int32, dpt.float32)) + + elif dtype_str.startswith("float"): + assert dpt.isdtype(dt, (dpt.float16, dpt.float32, dpt.float64)) + assert not dpt.isdtype(dt, (dpt.bool, dpt.complex64, dpt.int8)) + + else: + assert dpt.isdtype(dt, (dpt.complex64, dpt.complex128)) + assert not dpt.isdtype(dt, (dpt.bool, dpt.uint64, dpt.int8)) + + +@pytest.mark.parametrize( + "kind", + [ + [dpt.int32, dpt.bool], + "f4", + float, + 123, + "complex", + ], +) +def test_isdtype_invalid_kind(kind): + with pytest.raises((TypeError, ValueError)): + dpt.isdtype(dpt.int32, kind)