Skip to content

Commit f0f57e8

Browse files
committed
Add arithmetic operators w/o tests
1 parent c13a59f commit f0f57e8

File tree

2 files changed

+114
-53
lines changed

2 files changed

+114
-53
lines changed

arrayfire/array_api/_array_object.py

Lines changed: 113 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,25 @@
33
import array as py_array
44
import ctypes
55
from dataclasses import dataclass
6+
from typing import Any
67

78
from arrayfire import backend, safe_call # TODO refactoring
89
from arrayfire.array import _in_display_dims_limit # TODO refactoring
910

10-
from ._dtypes import CShape, Dtype, c_dim_t, float32, supported_dtypes
11-
from ._utils import Device, PointerSource, to_str
11+
from ._dtypes import CShape, Dtype
12+
from ._dtypes import bool as af_bool
13+
from ._dtypes import c_dim_t
14+
from ._dtypes import complex64 as af_complex64
15+
from ._dtypes import complex128 as af_complex128
16+
from ._dtypes import float32 as af_float32
17+
from ._dtypes import float64 as af_float64
18+
from ._dtypes import int64 as af_int64
19+
from ._dtypes import supported_dtypes
20+
from ._dtypes import uint64 as af_uint64
21+
from ._utils import PointerSource, is_number, to_str
1222

1323
ShapeType = tuple[int, ...]
24+
_bcast_var = False # HACK, TODO replace for actual bcast_var after refactoring
1425

1526

1627
@dataclass
@@ -40,7 +51,7 @@ def __init__(
4051

4152
if dtype is None:
4253
_no_initial_dtype = True
43-
dtype = float32
54+
dtype = af_float32
4455

4556
if x is None:
4657
if not shape: # shape is None or empty tuple
@@ -134,15 +145,47 @@ def __neg__(self) -> Array:
134145
"""
135146
Return -self
136147
"""
137-
# return 0 - self
138-
raise NotImplementedError
148+
return 0 - self
139149

140150
def __add__(self, other: int | float | Array, /) -> Array:
151+
# TODO discuss either we need to support complex and bool as other input type
141152
"""
142153
Return self + other.
143154
"""
144-
# return _binary_func(self, other, backend.get().af_add) # TODO
145-
raise NotImplementedError
155+
return _process_c_function(self, other, backend.get().af_add)
156+
157+
def __sub__(self, other: int | float | bool | complex | Array, /) -> Array:
158+
"""
159+
Return self - other.
160+
"""
161+
return _process_c_function(self, other, backend.get().af_sub)
162+
163+
def __mul__(self, other: int | float | bool | complex | Array, /) -> Array:
164+
"""
165+
Return self * other.
166+
"""
167+
return _process_c_function(self, other, backend.get().af_mul)
168+
169+
def __truediv__(self, other: int | float | bool | complex | Array, /) -> Array:
170+
"""
171+
Return self / other.
172+
"""
173+
return _process_c_function(self, other, backend.get().af_div)
174+
175+
def __floordiv__(self, other: int | float | bool | complex | Array, /) -> Array:
176+
return NotImplemented
177+
178+
def __mod__(self, other: int | float | bool | complex | Array, /) -> Array:
179+
"""
180+
Return self % other.
181+
"""
182+
return _process_c_function(self, other, backend.get().af_mod)
183+
184+
def __pow__(self, other: int | float | bool | complex | Array, /) -> Array:
185+
"""
186+
Return self ** other.
187+
"""
188+
return _process_c_function(self, other, backend.get().af_pow)
146189

147190
@property
148191
def dtype(self) -> Dtype:
@@ -151,7 +194,7 @@ def dtype(self) -> Dtype:
151194
return _c_api_value_to_dtype(out.value)
152195

153196
@property
154-
def device(self) -> Device:
197+
def device(self) -> Any:
155198
raise NotImplementedError
156199

157200
@property
@@ -232,41 +275,66 @@ def _str_to_dtype(value: int) -> Dtype:
232275

233276
raise TypeError("There is no supported dtype that matches passed dtype typecode.")
234277

235-
# TODO
236-
# def _binary_func(lhs: int | float | Array, rhs: int | float | Array, c_func: Any) -> Array: # TODO replace Any
237-
# out = Array()
238-
# other = rhs
239-
240-
# if is_number(rhs):
241-
# ldims = _fill_dim4_tuple(lhs.shape)
242-
# rty = implicit_dtype(rhs, lhs.type())
243-
# other = Array()
244-
# other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty.value)
245-
# elif not isinstance(rhs, Array):
246-
# raise TypeError("Invalid parameter to binary function")
247-
248-
# safe_call(c_func(c_pointer(out.arr), lhs.arr, other.arr, _bcast_var.get()))
249-
250-
# return out
251-
252-
253-
# TODO replace candidate below
254-
# def dim4_to_tuple(shape: ShapeType, default: int=1) -> ShapeType:
255-
# assert(isinstance(dims, tuple))
256-
257-
# if (default is not None):
258-
# assert(is_number(default))
259-
260-
# out = [default]*4
261-
262-
# for i, dim in enumerate(dims):
263-
# out[i] = dim
264-
265-
# return tuple(out)
266-
267-
# def _fill_dim4_tuple(shape: ShapeType) -> tuple[int, ...]:
268-
# out = tuple([1 if value is None else value for value in shape])
269-
# if len(out) == 4:
270-
# return out
271278

272-
# return out + (1,)*(4-len(out))
279+
def _process_c_function(
280+
target: Array, other: int | float | bool | complex | Array, c_function: Any) -> Array:
281+
out = Array()
282+
283+
if isinstance(other, Array):
284+
safe_call(c_function(ctypes.pointer(out.arr), target.arr, other.arr, _bcast_var))
285+
elif is_number(other):
286+
target_c_shape = CShape(*target.shape)
287+
other_dtype = _implicit_dtype(other, target.dtype)
288+
other_array = _constant_array(other, target_c_shape, other_dtype)
289+
safe_call(c_function(ctypes.pointer(out.arr), target.arr, other_array.arr, _bcast_var))
290+
else:
291+
raise TypeError(f"{type(other)} is not supported and can not be passed to C binary function.")
292+
293+
return out
294+
295+
296+
def _implicit_dtype(value: int | float | bool | complex, array_dtype: Dtype) -> Dtype:
297+
if isinstance(value, bool):
298+
value_dtype = af_bool
299+
if isinstance(value, int):
300+
value_dtype = af_int64
301+
elif isinstance(value, float):
302+
value_dtype = af_float64
303+
elif isinstance(value, complex):
304+
value_dtype = af_complex128
305+
else:
306+
raise TypeError(f"{type(value)} is not supported and can not be converted to af.Dtype.")
307+
308+
if not (array_dtype == af_float32 or array_dtype == af_complex64):
309+
return value_dtype
310+
311+
if value_dtype == af_float64:
312+
return af_float32
313+
314+
if value_dtype == af_complex128:
315+
return af_complex64
316+
317+
return value_dtype
318+
319+
320+
def _constant_array(value: int | float | bool | complex, shape: CShape, dtype: Dtype) -> Array:
321+
out = Array()
322+
323+
if isinstance(value, complex):
324+
if dtype != af_complex64 and dtype != af_complex128:
325+
dtype = af_complex64
326+
327+
safe_call(backend.get().af_constant_complex(
328+
ctypes.pointer(out.arr), ctypes.c_double(value.real), ctypes.c_double(value.imag), 4,
329+
ctypes.pointer(shape.c_array), dtype))
330+
elif dtype == af_int64:
331+
safe_call(backend.get().af_constant_long(
332+
ctypes.pointer(out.arr), ctypes.c_longlong(value.real), 4, ctypes.pointer(shape.c_array)))
333+
elif dtype == af_uint64:
334+
safe_call(backend.get().af_constant_ulong(
335+
ctypes.pointer(out.arr), ctypes.c_ulonglong(value.real), 4, ctypes.pointer(shape.c_array)))
336+
else:
337+
safe_call(backend.get().af_constant(
338+
ctypes.pointer(out.arr), ctypes.c_double(value), 4, ctypes.pointer(shape.c_array), dtype))
339+
340+
return out

arrayfire/array_api/_utils.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,6 @@
11
import ctypes
22
import enum
33
import numbers
4-
from typing import Any
5-
6-
7-
class Device(enum.Enum):
8-
# HACK. TODO make it real
9-
cpu = "cpu"
10-
gpu = "gpu"
114

125

136
class PointerSource(enum.Enum):
@@ -23,5 +16,5 @@ def to_str(c_str: ctypes.c_char_p) -> str:
2316
return str(c_str.value.decode("utf-8")) # type: ignore[union-attr]
2417

2518

26-
def is_number(number: Any) -> bool:
19+
def is_number(number: int | float | bool | complex) -> bool:
2720
return isinstance(number, numbers.Number)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy