11
11
Array class and helper functions.
12
12
"""
13
13
14
+ from .algorithm import sum , count
15
+ from .arith import cast
14
16
import inspect
15
17
import os
16
18
from .library import *
25
27
26
28
_display_dims_limit = None
27
29
30
+
28
31
def set_display_dims_limit (* dims ):
29
32
"""
30
33
Sets the dimension limit after which array's data won't get
@@ -44,6 +47,7 @@ def set_display_dims_limit(*dims):
44
47
global _display_dims_limit
45
48
_display_dims_limit = dims
46
49
50
+
47
51
def get_display_dims_limit ():
48
52
"""
49
53
Gets the dimension limit after which array's data won't get
@@ -67,6 +71,7 @@ def get_display_dims_limit():
67
71
"""
68
72
return _display_dims_limit
69
73
74
+
70
75
def _in_display_dims_limit (dims ):
71
76
if _is_running_in_py_charm :
72
77
return False
@@ -80,6 +85,7 @@ def _in_display_dims_limit(dims):
80
85
return False
81
86
return True
82
87
88
+
83
89
def _create_array (buf , numdims , idims , dtype , is_device ):
84
90
out_arr = c_void_ptr_t (0 )
85
91
c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
@@ -91,6 +97,7 @@ def _create_array(buf, numdims, idims, dtype, is_device):
91
97
numdims , c_pointer (c_dims ), dtype .value ))
92
98
return out_arr
93
99
100
+
94
101
def _create_strided_array (buf , numdims , idims , dtype , is_device , offset , strides ):
95
102
out_arr = c_void_ptr_t (0 )
96
103
c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
@@ -112,16 +119,15 @@ def _create_strided_array(buf, numdims, idims, dtype, is_device, offset, strides
112
119
location .value ))
113
120
return out_arr
114
121
122
+
115
123
def _create_empty_array (numdims , idims , dtype ):
116
124
out_arr = c_void_ptr_t (0 )
117
-
118
- if numdims == 0 : return out_arr
119
-
120
125
c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
121
126
safe_call (backend .get ().af_create_handle (c_pointer (out_arr ),
122
127
numdims , c_pointer (c_dims ), dtype .value ))
123
128
return out_arr
124
129
130
+
125
131
def constant_array (val , d0 , d1 = None , d2 = None , d3 = None , dtype = Dtype .f32 ):
126
132
"""
127
133
Internal function to create a C array. Should not be used externall.
@@ -176,6 +182,7 @@ def _binary_func(lhs, rhs, c_func):
176
182
177
183
return out
178
184
185
+
179
186
def _binary_funcr (lhs , rhs , c_func ):
180
187
out = Array ()
181
188
other = lhs
@@ -192,9 +199,10 @@ def _binary_funcr(lhs, rhs, c_func):
192
199
193
200
return out
194
201
202
+
195
203
def _ctype_to_lists (ctype_arr , dim , shape , offset = 0 ):
196
204
if (dim == 0 ):
197
- return list (ctype_arr [offset : offset + shape [0 ]])
205
+ return list (ctype_arr [offset : offset + shape [0 ]])
198
206
else :
199
207
dim_len = shape [dim ]
200
208
res = [[]] * dim_len
@@ -203,6 +211,7 @@ def _ctype_to_lists(ctype_arr, dim, shape, offset=0):
203
211
offset += shape [0 ]
204
212
return res
205
213
214
+
206
215
def _slice_to_length (key , dim ):
207
216
tkey = [key .start , key .stop , key .step ]
208
217
@@ -221,6 +230,7 @@ def _slice_to_length(key, dim):
221
230
222
231
return int (((tkey [1 ] - tkey [0 ] - 1 ) / tkey [2 ]) + 1 )
223
232
233
+
224
234
def _get_info (dims , buf_len ):
225
235
elements = 1
226
236
numdims = 0
@@ -250,6 +260,7 @@ def _get_indices(key):
250
260
251
261
return inds
252
262
263
+
253
264
def _get_assign_dims (key , idims ):
254
265
255
266
dims = [1 ]* 4
@@ -296,6 +307,7 @@ def _get_assign_dims(key, idims):
296
307
else :
297
308
raise IndexError ("Invalid type while assigning to arrayfire.array" )
298
309
310
+
299
311
def transpose (a , conj = False ):
300
312
"""
301
313
Perform the transpose on an input.
@@ -318,6 +330,7 @@ def transpose(a, conj=False):
318
330
safe_call (backend .get ().af_transpose (c_pointer (out .arr ), a .arr , conj ))
319
331
return out
320
332
333
+
321
334
def transpose_inplace (a , conj = False ):
322
335
"""
323
336
Perform inplace transpose on an input.
@@ -338,6 +351,7 @@ def transpose_inplace(a, conj=False):
338
351
"""
339
352
safe_call (backend .get ().af_transpose_inplace (a .arr , conj ))
340
353
354
+
341
355
class Array (BaseArray ):
342
356
343
357
"""
@@ -447,8 +461,8 @@ def __init__(self, src=None, dims=None, dtype=None, is_device=False, offset=None
447
461
448
462
super (Array , self ).__init__ ()
449
463
450
- buf = None
451
- buf_len = 0
464
+ buf = None
465
+ buf_len = 0
452
466
453
467
if dtype is not None :
454
468
if isinstance (dtype , str ):
@@ -458,7 +472,7 @@ def __init__(self, src=None, dims=None, dtype=None, is_device=False, offset=None
458
472
else :
459
473
type_char = None
460
474
461
- _type_char = 'f'
475
+ _type_char = 'f'
462
476
463
477
if src is not None :
464
478
@@ -469,12 +483,12 @@ def __init__(self, src=None, dims=None, dtype=None, is_device=False, offset=None
469
483
host = __import__ ("array" )
470
484
471
485
if isinstance (src , host .array ):
472
- buf ,buf_len = src .buffer_info ()
486
+ buf , buf_len = src .buffer_info ()
473
487
_type_char = src .typecode
474
488
numdims , idims = _get_info (dims , buf_len )
475
489
elif isinstance (src , list ):
476
490
tmp = host .array ('f' , src )
477
- buf ,buf_len = tmp .buffer_info ()
491
+ buf , buf_len = tmp .buffer_info ()
478
492
_type_char = tmp .typecode
479
493
numdims , idims = _get_info (dims , buf_len )
480
494
elif isinstance (src , int ) or isinstance (src , c_void_ptr_t ):
@@ -498,7 +512,7 @@ def __init__(self, src=None, dims=None, dtype=None, is_device=False, offset=None
498
512
raise TypeError ("src is an object of unsupported class" )
499
513
500
514
if (type_char is not None and
501
- type_char != _type_char ):
515
+ type_char != _type_char ):
502
516
raise TypeError ("Can not create array of requested type from input data type" )
503
517
if (offset is None and strides is None ):
504
518
self .arr = _create_array (buf , numdims , idims , to_dtype [_type_char ], is_device )
@@ -620,8 +634,8 @@ def strides(self):
620
634
s2 = c_dim_t (0 )
621
635
s3 = c_dim_t (0 )
622
636
safe_call (backend .get ().af_get_strides (c_pointer (s0 ), c_pointer (s1 ),
623
- c_pointer (s2 ), c_pointer (s3 ), self .arr ))
624
- strides = (s0 .value ,s1 .value ,s2 .value ,s3 .value )
637
+ c_pointer (s2 ), c_pointer (s3 ), self .arr ))
638
+ strides = (s0 .value , s1 .value , s2 .value , s3 .value )
625
639
return strides [:self .numdims ()]
626
640
627
641
def elements (self ):
@@ -680,8 +694,8 @@ def dims(self):
680
694
d2 = c_dim_t (0 )
681
695
d3 = c_dim_t (0 )
682
696
safe_call (backend .get ().af_get_dims (c_pointer (d0 ), c_pointer (d1 ),
683
- c_pointer (d2 ), c_pointer (d3 ), self .arr ))
684
- dims = (d0 .value ,d1 .value ,d2 .value ,d3 .value )
697
+ c_pointer (d2 ), c_pointer (d3 ), self .arr ))
698
+ dims = (d0 .value , d1 .value , d2 .value , d3 .value )
685
699
return dims [:self .numdims ()]
686
700
687
701
@property
@@ -906,7 +920,7 @@ def __itruediv__(self, other):
906
920
"""
907
921
Perform self /= other.
908
922
"""
909
- self = _binary_func (self , other , backend .get ().af_div )
923
+ self = _binary_func (self , other , backend .get ().af_div )
910
924
return self
911
925
912
926
def __rtruediv__ (self , other ):
@@ -925,7 +939,7 @@ def __idiv__(self, other):
925
939
"""
926
940
Perform other / self.
927
941
"""
928
- self = _binary_func (self , other , backend .get ().af_div )
942
+ self = _binary_func (self , other , backend .get ().af_div )
929
943
return self
930
944
931
945
def __rdiv__ (self , other ):
@@ -944,7 +958,7 @@ def __imod__(self, other):
944
958
"""
945
959
Perform self %= other.
946
960
"""
947
- self = _binary_func (self , other , backend .get ().af_mod )
961
+ self = _binary_func (self , other , backend .get ().af_mod )
948
962
return self
949
963
950
964
def __rmod__ (self , other ):
@@ -963,7 +977,7 @@ def __ipow__(self, other):
963
977
"""
964
978
Perform self **= other.
965
979
"""
966
- self = _binary_func (self , other , backend .get ().af_pow )
980
+ self = _binary_func (self , other , backend .get ().af_pow )
967
981
return self
968
982
969
983
def __rpow__ (self , other ):
@@ -1106,15 +1120,15 @@ def logical_and(self, other):
1106
1120
Return self && other.
1107
1121
"""
1108
1122
out = Array ()
1109
- safe_call (backend .get ().af_and (c_pointer (out .arr ), self .arr , other .arr )) # TODO: bcast var?
1123
+ safe_call (backend .get ().af_and (c_pointer (out .arr ), self .arr , other .arr )) # TODO: bcast var?
1110
1124
return out
1111
1125
1112
1126
def logical_or (self , other ):
1113
1127
"""
1114
1128
Return self || other.
1115
1129
"""
1116
1130
out = Array ()
1117
- safe_call (backend .get ().af_or (c_pointer (out .arr ), self .arr , other .arr )) # TODO: bcast var?
1131
+ safe_call (backend .get ().af_or (c_pointer (out .arr ), self .arr , other .arr )) # TODO: bcast var?
1118
1132
return out
1119
1133
1120
1134
def __nonzero__ (self ):
@@ -1144,12 +1158,11 @@ def __getitem__(self, key):
1144
1158
inds = _get_indices (key )
1145
1159
1146
1160
safe_call (backend .get ().af_index_gen (c_pointer (out .arr ),
1147
- self .arr , c_dim_t (n_dims ), inds .pointer ))
1161
+ self .arr , c_dim_t (n_dims ), inds .pointer ))
1148
1162
return out
1149
1163
except RuntimeError as e :
1150
1164
raise IndexError (str (e ))
1151
1165
1152
-
1153
1166
def __setitem__ (self , key , val ):
1154
1167
"""
1155
1168
Perform self[key] = val
@@ -1175,14 +1188,14 @@ def __setitem__(self, key, val):
1175
1188
n_dims = 1
1176
1189
other_arr = constant_array (val , int (num ), dtype = self .type ())
1177
1190
else :
1178
- other_arr = constant_array (val , tdims [0 ] , tdims [1 ], tdims [2 ], tdims [3 ], self .type ())
1191
+ other_arr = constant_array (val , tdims [0 ], tdims [1 ], tdims [2 ], tdims [3 ], self .type ())
1179
1192
del_other = True
1180
1193
else :
1181
1194
other_arr = val .arr
1182
1195
del_other = False
1183
1196
1184
1197
out_arr = c_void_ptr_t (0 )
1185
- inds = _get_indices (key )
1198
+ inds = _get_indices (key )
1186
1199
1187
1200
safe_call (backend .get ().af_assign_gen (c_pointer (out_arr ),
1188
1201
self .arr , c_dim_t (n_dims ), inds .pointer ,
@@ -1401,6 +1414,7 @@ def to_ndarray(self, output=None):
1401
1414
safe_call (backend .get ().af_get_data_ptr (c_void_ptr_t (output .ctypes .data ), tmp .arr ))
1402
1415
return output
1403
1416
1417
+
1404
1418
def display (a , precision = 4 ):
1405
1419
"""
1406
1420
Displays the contents of an array.
@@ -1426,6 +1440,7 @@ def display(a, precision=4):
1426
1440
safe_call (backend .get ().af_print_array_gen (name .encode ('utf-8' ),
1427
1441
a .arr , c_int_t (precision )))
1428
1442
1443
+
1429
1444
def save_array (key , a , filename , append = False ):
1430
1445
"""
1431
1446
Save an array to disk.
@@ -1457,6 +1472,7 @@ def save_array(key, a, filename, append=False):
1457
1472
append ))
1458
1473
return index .value
1459
1474
1475
+
1460
1476
def read_array (filename , index = None , key = None ):
1461
1477
"""
1462
1478
Read an array from disk.
@@ -1490,6 +1506,3 @@ def read_array(filename, index=None, key=None):
1490
1506
key .encode ('utf-8' )))
1491
1507
1492
1508
return out
1493
-
1494
- from .algorithm import (sum , count )
1495
- from .arith import cast
0 commit comments