Skip to content

Commit cdb7a92

Browse files
committed
Added to_list and to_ctypes_array
1 parent 4187b27 commit cdb7a92

File tree

2 files changed

+78
-8
lines changed

2 files changed

+78
-8
lines changed

arrayfire/array_api/_array_object.py

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,19 +124,20 @@ def __init__(
124124
ctypes.pointer(_cshape.c_array), ctypes.pointer(strides_cshape), dtype.c_api_value,
125125
pointer_source.value))
126126

127-
def __str__(self) -> str: # FIXME
127+
def __str__(self) -> str:
128+
# TODO change the look of array str. E.g., like np.array
128129
if not _in_display_dims_limit(self.shape):
129130
return _metadata_string(self.dtype, self.shape)
130131

131132
return _metadata_string(self.dtype) + _array_as_str(self)
132133

133-
def __repr__(self) -> str: # FIXME
134+
def __repr__(self) -> str:
134135
# return _metadata_string(self.dtype, self.shape)
135136
# TODO change the look of array representation. E.g., like np.array
136137
return _array_as_str(self)
137138

138139
def __len__(self) -> int:
139-
return self.shape[0] if self.shape else 0 # type: ignore[return-value]
140+
return self.shape[0] if self.shape else 0
140141

141142
# Arithmetic Operators
142143

@@ -475,17 +476,17 @@ def T(self) -> Array:
475476
raise NotImplementedError
476477

477478
@property
478-
def size(self) -> None | int:
479+
def size(self) -> int:
479480
# NOTE previously - elements()
480481
out = c_dim_t(0)
481482
safe_call(backend.get().af_get_elements(ctypes.pointer(out), self.arr))
482483
return out.value
483484

484485
@property
485486
def ndim(self) -> int:
486-
nd = ctypes.c_uint(0)
487-
safe_call(backend.get().af_get_numdims(ctypes.pointer(nd), self.arr))
488-
return nd.value
487+
out = ctypes.c_uint(0)
488+
safe_call(backend.get().af_get_numdims(ctypes.pointer(out), self.arr))
489+
return out.value
489490

490491
@property
491492
def shape(self) -> ShapeType:
@@ -510,6 +511,62 @@ def scalar(self) -> int | float | bool | complex:
510511
safe_call(backend.get().af_get_scalar(ctypes.pointer(out), self.arr))
511512
return out.value # type: ignore[no-any-return] # FIXME
512513

514+
def is_empty(self) -> bool:
515+
"""
516+
Check if the array is empty i.e. it has no elements.
517+
"""
518+
out = ctypes.c_bool()
519+
safe_call(backend.get().af_is_empty(ctypes.pointer(out), self.arr))
520+
return out.value
521+
522+
def to_list(self, row_major: bool = False) -> list: # FIXME return typings
523+
if self.is_empty():
524+
return []
525+
526+
array = _reorder(self) if row_major else self
527+
ctypes_array = _get_ctypes_array(array)
528+
529+
if array.ndim == 1:
530+
return list(ctypes_array)
531+
532+
out = []
533+
for i in range(array.size):
534+
idx = i
535+
sub_list = []
536+
for j in range(array.ndim):
537+
div = array.shape[j]
538+
sub_list.append(idx % div)
539+
idx //= div
540+
out.append(ctypes_array[sub_list[::-1]]) # type: ignore[call-overload] # FIXME
541+
return out
542+
543+
def to_ctype_array(self, row_major: bool = False) -> ctypes.Array:
544+
if self.is_empty():
545+
raise RuntimeError("Can not convert an empty array to ctype.")
546+
547+
array = _reorder(self) if row_major else self
548+
return _get_ctypes_array(array)
549+
550+
551+
def _get_ctypes_array(array: Array) -> ctypes.Array:
552+
c_shape = array.dtype.c_type * array.size
553+
ctypes_array = c_shape()
554+
safe_call(backend.get().af_get_data_ptr(ctypes.pointer(ctypes_array), array.arr))
555+
return ctypes_array
556+
557+
558+
def _reorder(array: Array) -> Array:
559+
"""
560+
Returns a reordered array to help interoperate with row major formats.
561+
"""
562+
if array.ndim == 1:
563+
return array
564+
565+
out = Array()
566+
c_shape = CShape(*(tuple(reversed(range(array.ndim))) + tuple(range(array.ndim, 4))))
567+
safe_call(backend.get().af_reorder(ctypes.pointer(out.arr), array.arr, *c_shape))
568+
return out
569+
513570

514571
def _array_as_str(array: Array) -> str:
515572
arr_str = ctypes.c_char_p(0)

arrayfire/array_api/tests/test_array_object.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from arrayfire.array_api import Array, float32, int16
44
from arrayfire.array_api._dtypes import supported_dtypes
55

6+
# TODO change separated methods with setup and teardown to avoid code duplication
7+
68

79
def test_empty_array() -> None:
810
array = Array()
@@ -105,7 +107,13 @@ def test_array_getitem() -> None:
105107
# TODO add more tests for different dtypes
106108

107109

108-
def test_array_sum() -> None:
110+
def test_array_to_list() -> None:
111+
# TODO add test of to_ctypes_array
112+
assert Array([1, 2, 3]).to_list() == [1, 2, 3]
113+
assert Array().to_list() == []
114+
115+
116+
def test_array_add() -> None:
109117
array = Array([1, 2, 3])
110118
res = array + 1
111119
assert res[0].scalar() == 2
@@ -123,6 +131,11 @@ def test_array_sum() -> None:
123131
assert res[2].scalar() == 12
124132

125133

134+
def test_array_add_raises_type_error() -> None:
135+
with pytest.raises(TypeError):
136+
Array([1, 2, 3]) + "15" # type: ignore[operator]
137+
138+
126139
def test_array_sub() -> None:
127140
array = Array([1, 2, 3])
128141
res = array - 1

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy