Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions dpctl/memory/_memory.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ from dpctl._backend cimport ( # noqa: E211
DPCTLaligned_alloc_device,
DPCTLaligned_alloc_host,
DPCTLaligned_alloc_shared,
DPCTLContext_AreEq,
DPCTLContext_Delete,
DPCTLDevice_Copy,
DPCTLEvent_Delete,
Expand Down Expand Up @@ -422,8 +423,10 @@ cdef class _Memory:
the memory of the instance
"""
cdef _USMBufferData src_buf
cdef const char* kind
cdef DPCTLSyclEventRef ERef = NULL
cdef bint same_contexts = False
cdef SyclQueue this_queue = None
cdef SyclQueue src_queue = None

if not hasattr(sycl_usm_ary, '__sycl_usm_array_interface__'):
raise ValueError(
Expand All @@ -439,23 +442,28 @@ cdef class _Memory:
"Source object is too large to "
"be accommondated in {} bytes buffer".format(self.nbytes)
)
kind = DPCTLUSM_GetPointerType(
src_buf.p, self.queue.get_sycl_context().get_context_ref())
if (kind == b'unknown'):
copy_via_host(
<void *>self.memory_ptr, self.queue, # dest
<void *>src_buf.p, src_buf.queue, # src
<size_t>src_buf.nbytes

src_queue = src_buf.queue
this_queue = self.queue
same_contexts = DPCTLContext_AreEq(
src_queue.get_sycl_context().get_context_ref(),
this_queue.get_sycl_context().get_context_ref()
)
else:
if (same_contexts):
ERef = DPCTLQueue_Memcpy(
self.queue.get_queue_ref(),
this_queue.get_queue_ref(),
<void *>self.memory_ptr,
<void *>src_buf.p,
<size_t>src_buf.nbytes
)
DPCTLEvent_Wait(ERef)
DPCTLEvent_Delete(ERef)
else:
copy_via_host(
<void *>self.memory_ptr, this_queue, # dest
<void *>src_buf.p, src_queue, # src
<size_t>src_buf.nbytes
)
else:
raise TypeError

Expand Down