diff --git a/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp b/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp index 822cceb3d7..8b7c4fe358 100644 --- a/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp +++ b/examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp @@ -6,16 +6,6 @@ namespace py = pybind11; -/* DPCTL C-API for usm_ndarray - UsmNDArray_GetData - UsmNDArray_GetNDim - UsmNDArray_GetShape - UsmNDArray_GetStrides - UsmNDArray_GetTypenum - UsmNDArray_GetFlags - UsmNDArray_GetQueueRef - */ - sycl::event keep_args_alive(sycl::queue q, py::object o1, py::object o2, @@ -47,51 +37,29 @@ sycl::event keep_args_alive(sycl::queue q, std::pair gemv(sycl::queue q, - py::object matrix, - py::object vector, - py::object result, + dpctl::tensor::usm_ndarray matrix, + dpctl::tensor::usm_ndarray vector, + dpctl::tensor::usm_ndarray result, const std::vector &depends = {}) { - PyObject *m_src = matrix.ptr(); - if (!PyObject_TypeCheck(m_src, &PyUSMArrayType)) { - throw std::runtime_error("Matrix is not a dpctl.tensor.usm_ndarray"); - } - - PyObject *v_src = vector.ptr(); - if (!PyObject_TypeCheck(v_src, &PyUSMArrayType)) { - throw std::runtime_error("Vector is not a dpctl.tensor.usm_ndarray"); - } - - PyObject *r_src = result.ptr(); - if (!PyObject_TypeCheck(r_src, &PyUSMArrayType)) { - throw std::runtime_error("Result is not a dpctl.tensor.usm_ndarray"); - } - - PyUSMArrayObject *m_usm_ary = reinterpret_cast(m_src); - PyUSMArrayObject *v_usm_ary = reinterpret_cast(v_src); - PyUSMArrayObject *r_usm_ary = reinterpret_cast(r_src); - - if (UsmNDArray_GetNDim(m_usm_ary) != 2 || - UsmNDArray_GetNDim(v_usm_ary) != 1 || - UsmNDArray_GetNDim(r_usm_ary) != 1) - { + if (matrix.get_ndim() != 2 || vector.get_ndim() != 1 || + result.get_ndim() != 1) { throw std::runtime_error( "Inconsistent dimensions, expecting matrix and a vector"); } - py::ssize_t *m_sh = UsmNDArray_GetShape(m_usm_ary); - py::ssize_t n = m_sh[0]; - py::ssize_t m = m_sh[1]; + py::ssize_t n = matrix.get_shape(0); // get 0-th element of the shape + py::ssize_t m = matrix.get_shape(1); - py::ssize_t *v_sh = UsmNDArray_GetShape(v_usm_ary); - py::ssize_t *r_sh = UsmNDArray_GetShape(r_usm_ary); - if (v_sh[0] != m || r_sh[0] != n) { + py::ssize_t v_dim = vector.get_shape(0); + py::ssize_t r_dim = result.get_shape(0); + if (v_dim != m || r_dim != n) { throw std::runtime_error("Inconsistent shapes."); } - int mat_flags = UsmNDArray_GetFlags(m_usm_ary); - int v_flags = UsmNDArray_GetFlags(v_usm_ary); - int r_flags = UsmNDArray_GetFlags(r_usm_ary); + int mat_flags = matrix.get_flags(); + int v_flags = vector.get_flags(); + int r_flags = result.get_flags(); if (!((mat_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)) && (v_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)) && @@ -100,9 +68,9 @@ gemv(sycl::queue q, throw std::runtime_error("Arrays must be contiguous."); } - int mat_typenum = UsmNDArray_GetTypenum(m_usm_ary); - int v_typenum = UsmNDArray_GetTypenum(v_usm_ary); - int r_typenum = UsmNDArray_GetTypenum(r_usm_ary); + int mat_typenum = matrix.get_typenum(); + int v_typenum = vector.get_typenum(); + int r_typenum = result.get_typenum(); if ((mat_typenum != v_typenum) || (r_typenum != v_typenum) || !((v_typenum == UAR_DOUBLE) || (v_typenum == UAR_FLOAT) || @@ -116,9 +84,9 @@ gemv(sycl::queue q, "Only real and complex floating point arrays are supported."); } - char *mat_typeless_ptr = UsmNDArray_GetData(m_usm_ary); - char *v_typeless_ptr = UsmNDArray_GetData(v_usm_ary); - char *r_typeless_ptr = UsmNDArray_GetData(r_usm_ary); + char *mat_typeless_ptr = matrix.get_data(); + char *v_typeless_ptr = vector.get_data(); + char *r_typeless_ptr = result.get_data(); sycl::event res_ev; if (v_typenum == UAR_DOUBLE) {