This repository was archived by the owner on Jan 26, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path ndarray.py
More file actions
68 lines (54 loc) · 2.09 KB
/
ndarray.py
File metadata and controls
68 lines (54 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
The array class for sharpy, a distributed implementation of the
array API as defined here: https://data-apis.org/array-api/latest
"""
#
# See __init__.py for an implementation overview
#
from . import _sharpy as _csp
from . import array_api as api
def slicefy(x):
    """Normalize one component of an index expression to a ``slice``.

    A ``slice`` is returned unchanged. Any other value is treated as a single
    position and becomes a unit-step slice that selects exactly that element.
    The index ``-1`` needs a special case: ``-1 + 1 == 0`` would produce the
    empty slice ``slice(-1, 0)``, so its stop is left open (``None``).
    """
    if not isinstance(x, slice):
        if x == -1:
            # Last element: an open-ended stop avoids the empty slice(-1, 0).
            return slice(x, None, 1)
        return slice(x, x + 1, 1)
    return x
def _as_method(fn, name):
    """Give a generated function the name/qualname of the ndarray member it implements."""
    fn.__name__ = name
    fn.__qualname__ = f"ndarray.{name}"
    return fn


def _ewbinop_method(name):
    """Build dunder method *name* for an element-wise binary operation.

    The backend opcode is the upper-cased method name (``__add__`` ->
    ``_csp.__ADD__``). It is looked up when the method is called, so an
    opcode missing from ``_csp`` only fails when that operator is used.
    ``other`` may be an ``ndarray`` (its wrapped tensor is passed) or any
    other value (passed through unchanged).
    """
    opcode = name.upper()

    def method(self, other):
        return ndarray(
            _csp.EWBinOp.op(
                getattr(_csp, opcode),
                self._t,
                other._t if isinstance(other, ndarray) else other,
            )
        )

    return _as_method(method, name)


def _ewunyop_method(name):
    """Build dunder method *name* for an element-wise unary operation.

    As for binary ops, the opcode ``_csp.<NAME>`` is looked up at call time.
    """
    opcode = name.upper()

    def method(self):
        return ndarray(_csp.EWUnyOp.op(getattr(_csp, opcode), self._t))

    return _as_method(method, name)


def _unyop_method(name):
    """Build method *name* that calls the same-named method on ``self._t``.

    The result is returned as-is (not wrapped in ``ndarray``).
    """

    def method(self):
        return getattr(self._t, name)()

    return _as_method(method, name)


def _attribute_property(name):
    """Build a read-only property forwarding attribute *name* to ``self._t``."""
    return property(lambda self: getattr(self._t, name))


def _generated_members():
    """Return the operator methods and attribute properties listed in ``api``.

    Categories are processed in a fixed order: element-wise binary dunders,
    element-wise unary dunders, unary ops, then attributes. If a name appears
    in more than one category, the later category wins. Non-dunder names in
    the element-wise categories are skipped here.
    """
    members = {}
    for name in api.api_categories["EWBinOp"]:
        if name.startswith("__"):
            members[name] = _ewbinop_method(name)
    for name in api.api_categories["EWUnyOp"]:
        if name.startswith("__"):
            members[name] = _ewunyop_method(name)
    for name in api.api_categories["UnyOp"]:
        members[name] = _unyop_method(name)
    for name in api.attributes:
        members[name] = _attribute_property(name)
    return members


def _normalize_key(key):
    """Convert an index expression into the list of slices passed to the backend.

    A non-tuple key is treated as a 1-tuple; every component goes through
    ``slicefy``, so integer indices become single-element slices.
    """
    if not isinstance(key, tuple):
        key = (key,)
    return [slicefy(x) for x in key]


class ndarray:
    """Array class for sharpy: a Python wrapper around a ``_sharpy`` tensor.

    The wrapped backend object is stored in ``self._t``. Operators, unary ops
    and attributes named in ``array_api`` are generated and forwarded to it.
    """

    def __init__(self, t):
        # t: backend tensor object; every operation below delegates to it.
        self._t = t

    def __repr__(self):
        return self._t.__repr__()

    # Add the generated operator methods and attribute properties to the
    # class namespace being built. At class scope ``locals()`` is that
    # namespace, so the members exist before ``type`` creates the class.
    # This matters for type-creation rules: a generated ``__eq__`` without
    # ``__hash__`` makes instances unhashable. Members defined below this
    # line override generated members with the same name.
    locals().update(_generated_members())

    def astype(self, dtype, copy=False):
        """Return a new ndarray wrapping ``self._t.astype(dtype, copy)``."""
        return ndarray(self._t.astype(dtype, copy))

    def to_device(self, device=""):
        """Return a new ndarray wrapping ``self._t.to_device(device)``."""
        return ndarray(self._t.to_device(device))

    def __getitem__(self, key):
        """Index the array; integer components of *key* are sent as one-element slices."""
        return ndarray(self._t.__getitem__(_normalize_key(key)))

    def __setitem__(self, key, value):
        """Assign *value* (an ndarray or a value passed through) into the selection *key*.

        Raises:
            ValueError: if *value* is an ndarray whose dtype differs from this array's.
        """
        key = _normalize_key(key)
        if isinstance(value, ndarray):
            if value._t.dtype != self._t.dtype:
                raise ValueError(
                    f"Mismatching data type in setitem: {value._t.dtype}, expecting {self._t.dtype}"
                )
            value = value._t
        self._t.__setitem__(key, value)