Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions dpctl/tensor/_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dpctl
Comment thread
oleksandr-pavlyk marked this conversation as resolved.


class Device:
"""
Class representing Data-API concept of device.

This is a wrapper around :class:`dpctl.SyclQueue` with custom
formatting. The class does not have public constructor,
but a class method to construct it from device= keyword
in Array-API functions.

Instance can be queried for ``sycl_queue``, ``sycl_context``,
or ``sycl_device``.
"""

def __new__(cls, *args, **kwargs):
raise TypeError("No public constructor")

@classmethod
def create_device(cls, dev):
"""
Device.create_device(device)

Creates instance of Device from argument.

Args:
device: None, :class:`.Device`, :class:`dpctl.SyclQueue`, or
a :class:`dpctl.SyclDevice` corresponding to a root
SYCL device.
Raises:
ValueError: if an instance of :class:`dpctl.SycDevice` corresponding
to a sub-device was specified as the argument
SyclQueueCreationError: if :class:`dpctl.SyclQueue` could not be
created from the argument
"""
obj = super().__new__(cls)
if isinstance(dev, Device):
obj.sycl_queue_ = dev.sycl_queue
elif isinstance(dev, dpctl.SyclQueue):
obj.sycl_queue_ = dev
elif isinstance(dev, dpctl.SyclDevice):
par = dev.parent_device
if par is None:
obj.sycl_queue_ = dpctl.SyclQueue(dev)
else:
raise ValueError(
"Using non-root device {} to specify offloading "
"target is ambiguous. Please use dpctl.SyclQueue "
"targeting this device".format(dev)
)
else:
obj.sycl_queue_ = dpctl.SyclQueue(dev)
return obj

@property
def sycl_queue(self):
"""
:class:`dpctl.SyclQueue` used to offload to this :class:`.Device`.
"""
return self.sycl_queue_

@property
def sycl_context(self):
"""
:class:`dpctl.SyclContext` associated with this :class:`.Device`.
"""
return self.sycl_queue_.sycl_context

@property
def sycl_device(self):
"""
:class:`dpctl.SyclDevice` targed by this :class:`.Device`.
"""
return self.sycl_queue_.sycl_device

def __repr__(self):
try:
sd = self.sycl_device
except AttributeError:
raise ValueError(
"Instance of {} is not initialized".format(self.__class__)
)
try:
fs = sd.filter_string
return "Device({})".format(fs)
except TypeError:
# This is a sub-device
return repr(self.sycl_queue)
46 changes: 44 additions & 2 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import numpy as np
import dpctl
import dpctl.memory as dpmem

from ._device import Device

from cpython.mem cimport PyMem_Free
from cpython.tuple cimport PyTuple_New, PyTuple_SetItem

Expand Down Expand Up @@ -181,8 +183,8 @@ cdef class usm_ndarray:
raise ValueError(
"buffer='{}' is not understood. "
"Recognized values are 'device', 'shared', 'host', "
"or an object with __sycl_usm_array_interface__ "
"property".format(buffer))
"an instance of `MemoryUSM*` object, or a usm_ndarray"
"".format(buffer))
elif isinstance(buffer, usm_ndarray):
_buffer = buffer.usm_data
else:
Expand Down Expand Up @@ -428,6 +430,13 @@ cdef class usm_ndarray:
q = self.sycl_queue
return q.sycl_device

@property
def device(self):
"""
Returns data-API object representing residence of the array data.
"""
return Device.create_device(self.sycl_queue)

@property
def sycl_context(self):
"""
Expand Down Expand Up @@ -475,6 +484,39 @@ cdef class usm_ndarray:
res.flags_ |= (self.flags_ & USM_ARRAY_WRITEABLE)
return res

def to_device(self, target_device):
"""
Transfer array to target device
"""
d = Device.create_device(target_device)
if (d.sycl_device == self.sycl_device):
return self
elif (d.sycl_context == self.sycl_context):
res = usm_ndarray(
self.shape,
self.dtype,
buffer=self.usm_data,
strides=self.strides,
offset=self.get_offset()
)
res.flags_ = self.flags
return res
else:
nbytes = self.usm_data.nbytes
new_buffer = type(self.usm_data)(
nbytes, queue=d.sycl_queue
)
new_buffer.copy_from_device(self.usm_data)
res = usm_ndarray(
self.shape,
self.dtype,
buffer=new_buffer,
strides=self.strides,
offset=self.get_offset()
)
res.flags_ = self.flags
return res


cdef usm_ndarray _real_view(usm_ndarray ary):
"""
Expand Down