From f15651bbfa6fb6e601315e62e9ac0a9977b2758e Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 8 Apr 2021 14:12:42 -0500 Subject: [PATCH] Added get_queue_ref_from_ptr_and_syclobj Implements DPCTLSyclQueueRef get_queue_ref_from_ptr_and_syclobj( DPCTLSyclUSMRef ptr, object syclobj) This function is to help users of __sycl_usm_array_interface__ to create a queue from a pointer and a valid syclobj entry in the interface. Currently supported variants are: - filter selector string : creates queue for the device created from the selector - dpctl.SyclQueue : use given queue - dpctl.SyclContext : find device from ptr and context, create queue from context : and device - capsule with SyclQueueRef : use this queue - capsule with SyclContextRef : use this context to recover device and create queue - any python object that implements method _get_capsule() : use that capsule as outlined above --- dpctl/memory/_memory.pxd | 5 ++- dpctl/memory/_memory.pyx | 78 ++++++++++++++++++++++++++++++++-------- 2 files changed, 67 insertions(+), 16 deletions(-) diff --git a/dpctl/memory/_memory.pxd b/dpctl/memory/_memory.pxd index 9aab4dcee3..d44081d260 100644 --- a/dpctl/memory/_memory.pxd +++ b/dpctl/memory/_memory.pxd @@ -22,12 +22,15 @@ in dpctl.memory._memory.pyx. 
""" -from .._backend cimport DPCTLSyclUSMRef +from .._backend cimport DPCTLSyclUSMRef, DPCTLSyclQueueRef from .._sycl_context cimport SyclContext from .._sycl_device cimport SyclDevice from .._sycl_queue cimport SyclQueue +cdef DPCTLSyclQueueRef get_queue_ref_from_ptr_and_syclobj( + DPCTLSyclUSMRef ptr, object syclobj) + cdef class _Memory: cdef DPCTLSyclUSMRef memory_ptr cdef Py_ssize_t nbytes diff --git a/dpctl/memory/_memory.pyx b/dpctl/memory/_memory.pyx index a74eacb493..dfa96f8f31 100644 --- a/dpctl/memory/_memory.pyx +++ b/dpctl/memory/_memory.pyx @@ -32,6 +32,7 @@ from .._sycl_queue_manager cimport get_current_queue from cpython cimport Py_buffer from cpython.bytes cimport PyBytes_AS_STRING, PyBytes_FromStringAndSize +from cpython cimport pycapsule import numpy as np @@ -41,10 +42,63 @@ __all__ = [ "MemoryUSMDevice" ] -cdef _throw_sycl_usm_ary_iface(): - raise ValueError("__sycl_usm_array_interface__ is malformed") +cdef object _sycl_usm_ary_iface_error(): + return ValueError("__sycl_usm_array_interface__ is malformed") +cdef DPCTLSyclQueueRef _queue_ref_copy_from_SyclQueue(SyclQueue q): + return DPCTLQueue_Copy(q.get_queue_ref()) + + +cdef DPCTLSyclQueueRef _queue_ref_copy_from_USMRef_and_SyclContext( + DPCTLSyclUSMRef ptr, SyclContext ctx): + """ Obtain device from pointer and sycl context, use + context and device to create a queue from which this memory + can be accessible. 
+ """ + cdef SyclDevice dev = _Memory.get_pointer_device(ptr, ctx) + cdef DPCTLSyclContextRef CRef = NULL + cdef DPCTLSyclDeviceRef DRef = NULL + CRef = ctx.get_context_ref() + DRef = dev.get_device_ref() + return DPCTLQueue_Create(CRef, DRef, NULL, 0) + + +cdef DPCTLSyclQueueRef get_queue_ref_from_ptr_and_syclobj( + DPCTLSyclUSMRef ptr, object syclobj): + """ Constructs queue from pointer and syclobject from + __sycl_usm_array_interface__ + """ + cdef DPCTLSyclQueueRef QRef = NULL + cdef SyclContext ctx + if type(syclobj) is SyclQueue: + return _queue_ref_copy_from_SyclQueue( syclobj) + elif type(syclobj) is SyclContext: + ctx = syclobj + return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx) + elif type(syclobj) is str: + q = SyclQueue(syclobj) + return _queue_ref_copy_from_SyclQueue( q) + elif pycapsule.PyCapsule_IsValid(syclobj, "SyclQueueRef"): + q = SyclQueue(syclobj) + return _queue_ref_copy_from_SyclQueue( q) + elif pycapsule.PyCapsule_IsValid(syclobj, "SyclContextRef"): + ctx = SyclContext(syclobj) + return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx) + elif hasattr(syclobj, '_get_capsule'): + cap = syclobj._get_capsule() + if pycapsule.PyCapsule_IsValid(cap, "SyclQueueRef"): + q = SyclQueue(cap) + return _queue_ref_copy_from_SyclQueue( q) + elif pycapsule.PyCapsule_IsValid(cap, "SyclContextRef"): + ctx = SyclContext(cap) + return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx) + else: + return QRef + else: + return QRef + + cdef void copy_via_host(void *dest_ptr, SyclQueue dest_queue, void *src_ptr, SyclQueue src_queue, size_t nbytes): """ @@ -98,17 +152,18 @@ cdef class _BufferData: cdef Py_ssize_t arr_data_ptr cdef SyclDevice dev cdef SyclContext ctx + cdef DPCTLSyclQueueRef QRef = NULL if ary_version != 1: - _throw_sycl_usm_ary_iface() + raise _sycl_usm_ary_iface_error() if not ary_data_tuple or len(ary_data_tuple) != 2: - _throw_sycl_usm_ary_iface() + raise _sycl_usm_ary_iface_error() if not ary_shape or len(ary_shape) != 1 or 
ary_shape[0] < 1: raise ValueError try: dt = np.dtype(ary_typestr) except TypeError: - _throw_sycl_usm_ary_iface() + raise _sycl_usm_ary_iface_error() if (ary_strides and len(ary_strides) != 1 and ary_strides[0] != dt.itemsize): raise ValueError("Must be contiguous") @@ -116,7 +171,7 @@ cdef class _BufferData: if (not ary_syclobj or not isinstance(ary_syclobj, (dpctl.SyclQueue, dpctl.SyclContext))): - _throw_sycl_usm_ary_iface() + raise _sycl_usm_ary_iface_error() buf = _BufferData.__new__(_BufferData) arr_data_ptr = ary_data_tuple[0] @@ -125,15 +180,8 @@ cdef class _BufferData: buf.itemsize = (dt.itemsize) buf.nbytes = (ary_shape[0]) * buf.itemsize - if isinstance(ary_syclobj, dpctl.SyclQueue): - buf.queue = ary_syclobj - else: - # Obtain device from pointer and context - ctx = ary_syclobj - dev = _Memory.get_pointer_device(buf.p, ctx) - # Use context and device to create a queue to - # be able to copy memory - buf.queue = SyclQueue._create_from_context_and_device(ctx, dev) + QRef = get_queue_ref_from_ptr_and_syclobj(buf.p, ary_syclobj) + buf.queue = SyclQueue._create(QRef) return buf