5
5
from dataclasses import dataclass
6
6
from typing import Any
7
7
8
- from arrayfire import backend , safe_call # TODO refactoring
9
- from arrayfire .array import _in_display_dims_limit # TODO refactoring
8
+ from arrayfire import backend , safe_call # TODO refactor
9
+ from arrayfire .algorithm import count # TODO refactor
10
+ from arrayfire .array import _get_indices , _in_display_dims_limit # TODO refactor
10
11
11
12
from ._dtypes import CShape , Dtype
12
13
from ._dtypes import bool as af_bool
@@ -37,15 +38,15 @@ class Array:
37
38
# arrayfire's __radd__() instead of numpy's __add__()
38
39
__array_priority__ = 30
39
40
40
- # Initialisation
41
- arr = ctypes .c_void_p (0 )
42
-
43
41
def __init__ (
44
42
self , x : None | Array | py_array .array | int | ctypes .c_void_p | list = None , dtype : None | Dtype = None ,
45
43
pointer_source : PointerSource = PointerSource .host , shape : None | ShapeType = None ,
46
44
offset : None | ctypes ._SimpleCData [int ] = None , strides : None | ShapeType = None ) -> None :
47
45
_no_initial_dtype = False # HACK, FIXME
48
46
47
+ # Initialise array object
48
+ self .arr = ctypes .c_void_p (0 )
49
+
49
50
if isinstance (dtype , str ):
50
51
dtype = _str_to_dtype (dtype )
51
52
@@ -127,7 +128,7 @@ def __str__(self) -> str: # FIXME
127
128
if not _in_display_dims_limit (self .shape ):
128
129
return _metadata_string (self .dtype , self .shape )
129
130
130
- return _metadata_string (self .dtype ) + self . _as_str ( )
131
+ return _metadata_string (self .dtype ) + _array_as_str ( self )
131
132
132
133
def __repr__ (self ) -> str : # FIXME
133
134
return _metadata_string (self .dtype , self .shape )
@@ -173,6 +174,7 @@ def __truediv__(self, other: int | float | bool | complex | Array, /) -> Array:
173
174
return _process_c_function (self , other , backend .get ().af_div )
174
175
175
176
def __floordiv__ (self , other : int | float | bool | complex | Array , / ) -> Array :
177
+ # TODO
176
178
return NotImplemented
177
179
178
180
def __mod__ (self , other : int | float | bool | complex | Array , / ) -> Array :
@@ -187,6 +189,25 @@ def __pow__(self, other: int | float | bool | complex | Array, /) -> Array:
187
189
"""
188
190
return _process_c_function (self , other , backend .get ().af_pow )
189
191
192
+ def __matmul__ (self , other : Array , / ) -> Array :
193
+ # TODO
194
+ return NotImplemented
195
+
196
+ def __getitem__ (self , key : int | slice | tuple [int | slice ] | Array , / ) -> Array :
197
+ # TODO: API Specification - key: int | slice | ellipsis | tuple[int | slice] | Array
198
+ # TODO: refactor
199
+ out = Array ()
200
+ ndims = self .ndim
201
+
202
+ if isinstance (key , Array ) and key == af_bool .c_api_value :
203
+ ndims = 1
204
+ if count (key ) == 0 :
205
+ return out
206
+
207
+ safe_call (backend .get ().af_index_gen (
208
+ ctypes .pointer (out .arr ), self .arr , c_dim_t (ndims ), _get_indices (key ).pointer ))
209
+ return out
210
+
190
211
@property
191
212
def dtype (self ) -> Dtype :
192
213
out = ctypes .c_int ()
@@ -234,13 +255,23 @@ def shape(self) -> ShapeType:
234
255
ctypes .pointer (d0 ), ctypes .pointer (d1 ), ctypes .pointer (d2 ), ctypes .pointer (d3 ), self .arr ))
235
256
return (d0 .value , d1 .value , d2 .value , d3 .value )[:self .ndim ] # Skip passing None values
236
257
237
- def _as_str (self ) -> str :
238
- arr_str = ctypes .c_char_p (0 )
239
- # FIXME add description to passed arguments
240
- safe_call (backend .get ().af_array_to_string (ctypes .pointer (arr_str ), "" , self .arr , 4 , True ))
241
- py_str = to_str (arr_str )
242
- safe_call (backend .get ().af_free_host (arr_str ))
243
- return py_str
258
+ def scalar (self ) -> int | float | bool | complex :
259
+ """
260
+ Return the first element of the array
261
+ """
262
+ # BUG seg fault on empty array
263
+ out = self .dtype .c_type ()
264
+ safe_call (backend .get ().af_get_scalar (ctypes .pointer (out ), self .arr ))
265
+ return out .value # type: ignore[no-any-return] # FIXME
266
+
267
+
268
+ def _array_as_str (array : Array ) -> str :
269
+ arr_str = ctypes .c_char_p (0 )
270
+ # FIXME add description to passed arguments
271
+ safe_call (backend .get ().af_array_to_string (ctypes .pointer (arr_str ), "" , array .arr , 4 , True ))
272
+ py_str = to_str (arr_str )
273
+ safe_call (backend .get ().af_free_host (arr_str ))
274
+ return py_str
244
275
245
276
246
277
def _metadata_string (dtype : Dtype , dims : None | ShapeType = None ) -> str :
@@ -283,9 +314,8 @@ def _process_c_function(
283
314
if isinstance (other , Array ):
284
315
safe_call (c_function (ctypes .pointer (out .arr ), target .arr , other .arr , _bcast_var ))
285
316
elif is_number (other ):
286
- target_c_shape = CShape (* target .shape )
287
317
other_dtype = _implicit_dtype (other , target .dtype )
288
- other_array = _constant_array (other , target_c_shape , other_dtype )
318
+ other_array = _constant_array (other , CShape ( * target . shape ) , other_dtype )
289
319
safe_call (c_function (ctypes .pointer (out .arr ), target .arr , other_array .arr , _bcast_var ))
290
320
else :
291
321
raise TypeError (f"{ type (other )} is not supported and can not be passed to C binary function." )
@@ -326,7 +356,7 @@ def _constant_array(value: int | float | bool | complex, shape: CShape, dtype: D
326
356
327
357
safe_call (backend .get ().af_constant_complex (
328
358
ctypes .pointer (out .arr ), ctypes .c_double (value .real ), ctypes .c_double (value .imag ), 4 ,
329
- ctypes .pointer (shape .c_array ), dtype ))
359
+ ctypes .pointer (shape .c_array ), dtype . c_api_value ))
330
360
elif dtype == af_int64 :
331
361
safe_call (backend .get ().af_constant_long (
332
362
ctypes .pointer (out .arr ), ctypes .c_longlong (value .real ), 4 , ctypes .pointer (shape .c_array )))
@@ -335,6 +365,6 @@ def _constant_array(value: int | float | bool | complex, shape: CShape, dtype: D
335
365
ctypes .pointer (out .arr ), ctypes .c_ulonglong (value .real ), 4 , ctypes .pointer (shape .c_array )))
336
366
else :
337
367
safe_call (backend .get ().af_constant (
338
- ctypes .pointer (out .arr ), ctypes .c_double (value ), 4 , ctypes .pointer (shape .c_array ), dtype ))
368
+ ctypes .pointer (out .arr ), ctypes .c_double (value ), 4 , ctypes .pointer (shape .c_array ), dtype . c_api_value ))
339
369
340
370
return out
0 commit comments