3
3
import array as py_array
4
4
import ctypes
5
5
from dataclasses import dataclass
6
+ from typing import Any
6
7
7
8
from arrayfire import backend , safe_call # TODO refactoring
8
9
from arrayfire .array import _in_display_dims_limit # TODO refactoring
9
10
10
- from ._dtypes import CShape , Dtype , c_dim_t , float32 , supported_dtypes
11
- from ._utils import Device , PointerSource , to_str
11
+ from ._dtypes import CShape , Dtype
12
+ from ._dtypes import bool as af_bool
13
+ from ._dtypes import c_dim_t
14
+ from ._dtypes import complex64 as af_complex64
15
+ from ._dtypes import complex128 as af_complex128
16
+ from ._dtypes import float32 as af_float32
17
+ from ._dtypes import float64 as af_float64
18
+ from ._dtypes import int64 as af_int64
19
+ from ._dtypes import supported_dtypes
20
+ from ._dtypes import uint64 as af_uint64
21
+ from ._utils import PointerSource , is_number , to_str
12
22
13
23
ShapeType = tuple [int , ...]
24
+ _bcast_var = False # HACK, TODO replace for actual bcast_var after refactoring
14
25
15
26
16
27
@dataclass
@@ -40,7 +51,7 @@ def __init__(
40
51
41
52
if dtype is None :
42
53
_no_initial_dtype = True
43
- dtype = float32
54
+ dtype = af_float32
44
55
45
56
if x is None :
46
57
if not shape : # shape is None or empty tuple
@@ -134,15 +145,47 @@ def __neg__(self) -> Array:
134
145
"""
135
146
Return -self
136
147
"""
137
- # return 0 - self
138
- raise NotImplementedError
148
+ return 0 - self
139
149
140
150
def __add__ (self , other : int | float | Array , / ) -> Array :
151
+ # TODO discuss either we need to support complex and bool as other input type
141
152
"""
142
153
Return self + other.
143
154
"""
144
- # return _binary_func(self, other, backend.get().af_add) # TODO
145
- raise NotImplementedError
155
+ return _process_c_function (self , other , backend .get ().af_add )
156
+
157
+ def __sub__ (self , other : int | float | bool | complex | Array , / ) -> Array :
158
+ """
159
+ Return self - other.
160
+ """
161
+ return _process_c_function (self , other , backend .get ().af_sub )
162
+
163
+ def __mul__ (self , other : int | float | bool | complex | Array , / ) -> Array :
164
+ """
165
+ Return self * other.
166
+ """
167
+ return _process_c_function (self , other , backend .get ().af_mul )
168
+
169
+ def __truediv__ (self , other : int | float | bool | complex | Array , / ) -> Array :
170
+ """
171
+ Return self / other.
172
+ """
173
+ return _process_c_function (self , other , backend .get ().af_div )
174
+
175
+ def __floordiv__ (self , other : int | float | bool | complex | Array , / ) -> Array :
176
+ return NotImplemented
177
+
178
+ def __mod__ (self , other : int | float | bool | complex | Array , / ) -> Array :
179
+ """
180
+ Return self % other.
181
+ """
182
+ return _process_c_function (self , other , backend .get ().af_mod )
183
+
184
+ def __pow__ (self , other : int | float | bool | complex | Array , / ) -> Array :
185
+ """
186
+ Return self ** other.
187
+ """
188
+ return _process_c_function (self , other , backend .get ().af_pow )
146
189
147
190
@property
148
191
def dtype (self ) -> Dtype :
@@ -151,7 +194,7 @@ def dtype(self) -> Dtype:
151
194
return _c_api_value_to_dtype (out .value )
152
195
153
196
@property
154
- def device (self ) -> Device :
197
+ def device (self ) -> Any :
155
198
raise NotImplementedError
156
199
157
200
@property
@@ -232,41 +275,66 @@ def _str_to_dtype(value: int) -> Dtype:
232
275
233
276
raise TypeError ("There is no supported dtype that matches passed dtype typecode." )
234
277
235
- # TODO
236
- # def _binary_func(lhs: int | float | Array, rhs: int | float | Array, c_func: Any) -> Array: # TODO replace Any
237
- # out = Array()
238
- # other = rhs
239
-
240
- # if is_number(rhs):
241
- # ldims = _fill_dim4_tuple(lhs.shape)
242
- # rty = implicit_dtype(rhs, lhs.type())
243
- # other = Array()
244
- # other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty.value)
245
- # elif not isinstance(rhs, Array):
246
- # raise TypeError("Invalid parameter to binary function")
247
-
248
- # safe_call(c_func(c_pointer(out.arr), lhs.arr, other.arr, _bcast_var.get()))
249
-
250
- # return out
251
-
252
-
253
- # TODO replace candidate below
254
- # def dim4_to_tuple(shape: ShapeType, default: int=1) -> ShapeType:
255
- # assert(isinstance(dims, tuple))
256
-
257
- # if (default is not None):
258
- # assert(is_number(default))
259
-
260
- # out = [default]*4
261
-
262
- # for i, dim in enumerate(dims):
263
- # out[i] = dim
264
-
265
- # return tuple(out)
266
-
267
- # def _fill_dim4_tuple(shape: ShapeType) -> tuple[int, ...]:
268
- # out = tuple([1 if value is None else value for value in shape])
269
- # if len(out) == 4:
270
- # return out
271
278
272
- # return out + (1,)*(4-len(out))
279
+ def _process_c_function (
280
+ target : Array , other : int | float | bool | complex | Array , c_function : Any ) -> Array :
281
+ out = Array ()
282
+
283
+ if isinstance (other , Array ):
284
+ safe_call (c_function (ctypes .pointer (out .arr ), target .arr , other .arr , _bcast_var ))
285
+ elif is_number (other ):
286
+ target_c_shape = CShape (* target .shape )
287
+ other_dtype = _implicit_dtype (other , target .dtype )
288
+ other_array = _constant_array (other , target_c_shape , other_dtype )
289
+ safe_call (c_function (ctypes .pointer (out .arr ), target .arr , other_array .arr , _bcast_var ))
290
+ else :
291
+ raise TypeError (f"{ type (other )} is not supported and can not be passed to C binary function." )
292
+
293
+ return out
294
+
295
+
296
+ def _implicit_dtype (value : int | float | bool | complex , array_dtype : Dtype ) -> Dtype :
297
+ if isinstance (value , bool ):
298
+ value_dtype = af_bool
299
+ if isinstance (value , int ):
300
+ value_dtype = af_int64
301
+ elif isinstance (value , float ):
302
+ value_dtype = af_float64
303
+ elif isinstance (value , complex ):
304
+ value_dtype = af_complex128
305
+ else :
306
+ raise TypeError (f"{ type (value )} is not supported and can not be converted to af.Dtype." )
307
+
308
+ if not (array_dtype == af_float32 or array_dtype == af_complex64 ):
309
+ return value_dtype
310
+
311
+ if value_dtype == af_float64 :
312
+ return af_float32
313
+
314
+ if value_dtype == af_complex128 :
315
+ return af_complex64
316
+
317
+ return value_dtype
318
+
319
+
320
+ def _constant_array (value : int | float | bool | complex , shape : CShape , dtype : Dtype ) -> Array :
321
+ out = Array ()
322
+
323
+ if isinstance (value , complex ):
324
+ if dtype != af_complex64 and dtype != af_complex128 :
325
+ dtype = af_complex64
326
+
327
+ safe_call (backend .get ().af_constant_complex (
328
+ ctypes .pointer (out .arr ), ctypes .c_double (value .real ), ctypes .c_double (value .imag ), 4 ,
329
+ ctypes .pointer (shape .c_array ), dtype ))
330
+ elif dtype == af_int64 :
331
+ safe_call (backend .get ().af_constant_long (
332
+ ctypes .pointer (out .arr ), ctypes .c_longlong (value .real ), 4 , ctypes .pointer (shape .c_array )))
333
+ elif dtype == af_uint64 :
334
+ safe_call (backend .get ().af_constant_ulong (
335
+ ctypes .pointer (out .arr ), ctypes .c_ulonglong (value .real ), 4 , ctypes .pointer (shape .c_array )))
336
+ else :
337
+ safe_call (backend .get ().af_constant (
338
+ ctypes .pointer (out .arr ), ctypes .c_double (value ), 4 , ctypes .pointer (shape .c_array ), dtype ))
339
+
340
+ return out
0 commit comments