2
2
3
3
import array as py_array
4
4
import ctypes
5
- import math
6
5
from dataclasses import dataclass
7
6
8
7
from arrayfire import backend , safe_call # TODO refactoring
9
8
from arrayfire .array import _in_display_dims_limit # TODO refactoring
10
9
11
- from ._dtypes import Dtype , c_dim_t , float32 , supported_dtypes
10
+ from ._dtypes import CShape , Dtype , c_dim_t , float32 , supported_dtypes
12
11
from ._utils import Device , PointerSource , to_str
13
12
14
13
ShapeType = tuple [None | int , ...]
@@ -28,7 +27,6 @@ class Array:
28
27
__array_priority__ = 30
29
28
30
29
# Initialisation
31
- _array_buffer = _ArrayBuffer ()
32
30
arr = ctypes .c_void_p (0 )
33
31
34
32
def __init__ (
@@ -46,12 +44,12 @@ def __init__(
46
44
if x is None :
47
45
if not shape : # shape is None or empty tuple
48
46
safe_call (backend .get ().af_create_handle (
49
- ctypes .pointer (self .arr ), 0 , ctypes .pointer (dim4 () ), dtype .c_api_value ))
47
+ ctypes .pointer (self .arr ), 0 , ctypes .pointer (CShape (). c_array ), dtype .c_api_value ))
50
48
return
51
49
52
50
# NOTE: applies inplace changes for self.arr
53
51
safe_call (backend .get ().af_create_handle (
54
- ctypes .pointer (self .arr ), len (shape ), ctypes .pointer (dim4 (* shape )), dtype .c_api_value ))
52
+ ctypes .pointer (self .arr ), len (shape ), ctypes .pointer (CShape (* shape ). c_array ), dtype .c_api_value ))
55
53
return
56
54
57
55
if isinstance (x , Array ):
@@ -61,19 +59,16 @@ def __init__(
61
59
if isinstance (x , py_array .array ):
62
60
_type_char = x .typecode
63
61
_array_buffer = _ArrayBuffer (* x .buffer_info ())
64
- numdims , idims = _get_info (shape , _array_buffer .length )
65
62
66
63
elif isinstance (x , list ):
67
64
_array = py_array .array ("f" , x ) # BUG [True, False] -> dtype: f32 # TODO add int and float
68
65
_type_char = _array .typecode
69
66
_array_buffer = _ArrayBuffer (* _array .buffer_info ())
70
- numdims , idims = _get_info (shape , _array_buffer .length )
71
67
72
68
elif isinstance (x , int ) or isinstance (x , ctypes .c_void_p ): # TODO
73
69
_array_buffer = _ArrayBuffer (x if not isinstance (x , ctypes .c_void_p ) else x .value )
74
- numdims , idims = _get_info (shape , _array_buffer .length )
75
70
76
- if not math . prod ( idims ) :
71
+ if not shape :
77
72
raise RuntimeError ("Expected to receive the initial shape due to the x being a data pointer." )
78
73
79
74
if _no_initial_dtype :
@@ -84,34 +79,37 @@ def __init__(
84
79
else :
85
80
raise TypeError ("Passed object x is an object of unsupported class." )
86
81
82
+ _cshape = _get_cshape (shape , _array_buffer .length )
83
+
87
84
if not _no_initial_dtype and dtype .typecode != _type_char :
88
85
raise TypeError ("Can not create array of requested type from input data type" )
89
86
90
87
if not (offset or strides ):
91
88
if pointer_source == PointerSource .host :
92
89
safe_call (backend .get ().af_create_array (
93
- ctypes .pointer (self .arr ), ctypes .c_void_p (_array_buffer .address ), numdims ,
94
- ctypes .pointer (dim4 ( * idims ) ), dtype .c_api_value ))
90
+ ctypes .pointer (self .arr ), ctypes .c_void_p (_array_buffer .address ), _cshape . original_shape ,
91
+ ctypes .pointer (_cshape . c_array ), dtype .c_api_value ))
95
92
return
96
93
97
94
safe_call (backend .get ().af_device_array (
98
- ctypes .pointer (self .arr ), ctypes .c_void_p (_array_buffer .address ), numdims ,
99
- ctypes .pointer (dim4 ( * idims ) ), dtype .c_api_value ))
95
+ ctypes .pointer (self .arr ), ctypes .c_void_p (_array_buffer .address ), _cshape . original_shape ,
96
+ ctypes .pointer (_cshape . c_array ), dtype .c_api_value ))
100
97
return
101
98
102
- if offset is None : # TODO
99
+ if offset is None :
103
100
offset = c_dim_t (0 )
104
101
105
- if strides is None : # TODO
106
- strides = (1 , idims [0 ], idims [0 ]* idims [1 ], idims [0 ]* idims [1 ]* idims [2 ])
102
+ if strides is None :
103
+ strides = (1 , _cshape [0 ], _cshape [0 ]* _cshape [1 ], _cshape [0 ]* _cshape [1 ]* _cshape [2 ])
107
104
108
105
if len (strides ) < 4 :
109
106
strides += (strides [- 1 ], ) * (4 - len (strides ))
110
- strides_dim4 = dim4 (* strides )
107
+ strides_cshape = CShape (* strides ). c_array
111
108
112
109
safe_call (backend .get ().af_create_strided_array (
113
- ctypes .pointer (self .arr ), ctypes .c_void_p (_array_buffer .address ), offset , numdims ,
114
- ctypes .pointer (dim4 (* idims )), ctypes .pointer (strides_dim4 ), dtype .c_api_value , pointer_source .value ))
110
+ ctypes .pointer (self .arr ), ctypes .c_void_p (_array_buffer .address ), offset , _cshape .original_shape ,
111
+ ctypes .pointer (_cshape .c_array ), ctypes .pointer (strides_cshape ), dtype .c_api_value ,
112
+ pointer_source .value ))
115
113
116
114
def __str__ (self ) -> str : # FIXME
117
115
if not _in_display_dims_limit (self .shape ):
@@ -126,7 +124,7 @@ def __len__(self) -> int:
126
124
return self .shape [0 ] if self .shape else 0 # type: ignore[return-value]
127
125
128
126
def __pos__ (self ) -> Array :
129
- """y
127
+ """
130
128
Return +self
131
129
"""
132
130
return self
@@ -190,8 +188,7 @@ def shape(self) -> ShapeType:
190
188
d3 = c_dim_t (0 )
191
189
safe_call (backend .get ().af_get_dims (
192
190
ctypes .pointer (d0 ), ctypes .pointer (d1 ), ctypes .pointer (d2 ), ctypes .pointer (d3 ), self .arr ))
193
- dims = (d0 .value , d1 .value , d2 .value , d3 .value )
194
- return dims [:self .ndim ] # FIXME An array dimension must be None if and only if a dimension is unknown
191
+ return (d0 .value , d1 .value , d2 .value , d3 .value )[:self .ndim ] # Skip passing None values
195
192
196
193
def _as_str (self ) -> str :
197
194
arr_str = ctypes .c_char_p (0 )
@@ -201,30 +198,6 @@ def _as_str(self) -> str:
201
198
safe_call (backend .get ().af_free_host (arr_str ))
202
199
return py_str
203
200
204
- # def _get_metadata_str(self, show_dims: bool = True) -> str:
205
- # return (
206
- # "arrayfire.Array()\n"
207
- # f"Type: {self.dtype.typename}\n"
208
- # f"Dims: {str(self._dims) if show_dims else ''}")
209
-
210
- # @property
211
- # def dtype(self) -> ...:
212
- # dty = ctypes.c_int()
213
- # safe_call(backend.get().af_get_type(ctypes.pointer(dty), self.arr)) # -> new dty
214
-
215
- # @safe_call
216
- # def backend()
217
- # ...
218
-
219
- # @backend(safe=True)
220
- # def af_get_type(arr) -> ...:
221
- # dty = ctypes.c_int()
222
- # safe_call(backend.get().af_get_type(ctypes.pointer(dty), self.arr)) # -> new dty
223
- # return dty
224
-
225
- # def new_dtype():
226
- # return af_get_type(self.arr)
227
-
228
201
229
202
def _metadata_string (dtype : Dtype , dims : None | ShapeType = None ) -> str :
230
203
return (
@@ -233,20 +206,14 @@ def _metadata_string(dtype: Dtype, dims: None | ShapeType = None) -> str:
233
206
f"Dims: { str (dims ) if dims else '' } " )
234
207
235
208
236
- def _get_info (shape : None | tuple [int ], buffer_length : int ) -> tuple [int , list [int ]]:
237
- # TODO refactor
209
+ def _get_cshape (shape : None | tuple [int ], buffer_length : int ) -> CShape :
238
210
if shape :
239
- numdims = len (shape )
240
- idims = [1 ]* 4
241
- for i in range (numdims ):
242
- idims [i ] = shape [i ]
243
- elif (buffer_length != 0 ):
244
- idims = [buffer_length , 1 , 1 , 1 ]
245
- numdims = 1
246
- else :
247
- raise RuntimeError ("Invalid size" )
211
+ return CShape (* shape )
212
+
213
+ if buffer_length != 0 :
214
+ return CShape (buffer_length )
248
215
249
- return numdims , idims
216
+ raise RuntimeError ( "Shape and buffer length are size invalid." )
250
217
251
218
252
219
def _c_api_value_to_dtype (value : int ) -> Dtype :
@@ -282,16 +249,6 @@ def _str_to_dtype(value: int) -> Dtype:
282
249
# return out
283
250
284
251
285
- def dim4 (d0 : int = 1 , d1 : int = 1 , d2 : int = 1 , d3 : int = 1 ): # type: ignore # FIXME
286
- c_dim4 = c_dim_t * 4 # ctypes.c_int | ctypes.c_longlong * 4
287
- out = c_dim4 (1 , 1 , 1 , 1 )
288
-
289
- for i , dim in enumerate ((d0 , d1 , d2 , d3 )):
290
- if dim is not None :
291
- out [i ] = c_dim_t (dim )
292
-
293
- return out
294
-
295
252
# TODO replace candidate below
296
253
# def dim4_to_tuple(shape: ShapeType, default: int=1) -> ShapeType:
297
254
# assert(isinstance(dims, tuple))
0 commit comments