Skip to content

Commit 8cef774

Browse files
committed
Fix array init bug. Add __getitem__. Change pytest for active debug mode
1 parent f0f57e8 commit 8cef774

File tree

3 files changed

+66
-18
lines changed

3 files changed

+66
-18
lines changed

arrayfire/array_api/_array_object.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from dataclasses import dataclass
66
from typing import Any
77

8-
from arrayfire import backend, safe_call # TODO refactoring
9-
from arrayfire.array import _in_display_dims_limit # TODO refactoring
8+
from arrayfire import backend, safe_call # TODO refactor
9+
from arrayfire.algorithm import count # TODO refactor
10+
from arrayfire.array import _get_indices, _in_display_dims_limit # TODO refactor
1011

1112
from ._dtypes import CShape, Dtype
1213
from ._dtypes import bool as af_bool
@@ -37,15 +38,15 @@ class Array:
3738
# arrayfire's __radd__() instead of numpy's __add__()
3839
__array_priority__ = 30
3940

40-
# Initialisation
41-
arr = ctypes.c_void_p(0)
42-
4341
def __init__(
4442
self, x: None | Array | py_array.array | int | ctypes.c_void_p | list = None, dtype: None | Dtype = None,
4543
pointer_source: PointerSource = PointerSource.host, shape: None | ShapeType = None,
4644
offset: None | ctypes._SimpleCData[int] = None, strides: None | ShapeType = None) -> None:
4745
_no_initial_dtype = False # HACK, FIXME
4846

47+
# Initialise array object
48+
self.arr = ctypes.c_void_p(0)
49+
4950
if isinstance(dtype, str):
5051
dtype = _str_to_dtype(dtype)
5152

@@ -127,7 +128,7 @@ def __str__(self) -> str: # FIXME
127128
if not _in_display_dims_limit(self.shape):
128129
return _metadata_string(self.dtype, self.shape)
129130

130-
return _metadata_string(self.dtype) + self._as_str()
131+
return _metadata_string(self.dtype) + _array_as_str(self)
131132

132133
def __repr__(self) -> str: # FIXME
133134
return _metadata_string(self.dtype, self.shape)
@@ -173,6 +174,7 @@ def __truediv__(self, other: int | float | bool | complex | Array, /) -> Array:
173174
return _process_c_function(self, other, backend.get().af_div)
174175

175176
def __floordiv__(self, other: int | float | bool | complex | Array, /) -> Array:
177+
# TODO
176178
return NotImplemented
177179

178180
def __mod__(self, other: int | float | bool | complex | Array, /) -> Array:
@@ -187,6 +189,25 @@ def __pow__(self, other: int | float | bool | complex | Array, /) -> Array:
187189
"""
188190
return _process_c_function(self, other, backend.get().af_pow)
189191

192+
def __matmul__(self, other: Array, /) -> Array:
193+
# TODO
194+
return NotImplemented
195+
196+
def __getitem__(self, key: int | slice | tuple[int | slice] | Array, /) -> Array:
197+
# TODO: API Specification - key: int | slice | ellipsis | tuple[int | slice] | Array
198+
# TODO: refactor
199+
out = Array()
200+
ndims = self.ndim
201+
202+
if isinstance(key, Array) and key == af_bool.c_api_value:
203+
ndims = 1
204+
if count(key) == 0:
205+
return out
206+
207+
safe_call(backend.get().af_index_gen(
208+
ctypes.pointer(out.arr), self.arr, c_dim_t(ndims), _get_indices(key).pointer))
209+
return out
210+
190211
@property
191212
def dtype(self) -> Dtype:
192213
out = ctypes.c_int()
@@ -234,13 +255,23 @@ def shape(self) -> ShapeType:
234255
ctypes.pointer(d0), ctypes.pointer(d1), ctypes.pointer(d2), ctypes.pointer(d3), self.arr))
235256
return (d0.value, d1.value, d2.value, d3.value)[:self.ndim] # Skip passing None values
236257

237-
def _as_str(self) -> str:
238-
arr_str = ctypes.c_char_p(0)
239-
# FIXME add description to passed arguments
240-
safe_call(backend.get().af_array_to_string(ctypes.pointer(arr_str), "", self.arr, 4, True))
241-
py_str = to_str(arr_str)
242-
safe_call(backend.get().af_free_host(arr_str))
243-
return py_str
258+
def scalar(self) -> int | float | bool | complex:
259+
"""
260+
Return the first element of the array
261+
"""
262+
# BUG seg fault on empty array
263+
out = self.dtype.c_type()
264+
safe_call(backend.get().af_get_scalar(ctypes.pointer(out), self.arr))
265+
return out.value # type: ignore[no-any-return] # FIXME
266+
267+
268+
def _array_as_str(array: Array) -> str:
269+
arr_str = ctypes.c_char_p(0)
270+
# FIXME add description to passed arguments
271+
safe_call(backend.get().af_array_to_string(ctypes.pointer(arr_str), "", array.arr, 4, True))
272+
py_str = to_str(arr_str)
273+
safe_call(backend.get().af_free_host(arr_str))
274+
return py_str
244275

245276

246277
def _metadata_string(dtype: Dtype, dims: None | ShapeType = None) -> str:
@@ -283,9 +314,8 @@ def _process_c_function(
283314
if isinstance(other, Array):
284315
safe_call(c_function(ctypes.pointer(out.arr), target.arr, other.arr, _bcast_var))
285316
elif is_number(other):
286-
target_c_shape = CShape(*target.shape)
287317
other_dtype = _implicit_dtype(other, target.dtype)
288-
other_array = _constant_array(other, target_c_shape, other_dtype)
318+
other_array = _constant_array(other, CShape(*target.shape), other_dtype)
289319
safe_call(c_function(ctypes.pointer(out.arr), target.arr, other_array.arr, _bcast_var))
290320
else:
291321
raise TypeError(f"{type(other)} is not supported and can not be passed to C binary function.")
@@ -326,7 +356,7 @@ def _constant_array(value: int | float | bool | complex, shape: CShape, dtype: D
326356

327357
safe_call(backend.get().af_constant_complex(
328358
ctypes.pointer(out.arr), ctypes.c_double(value.real), ctypes.c_double(value.imag), 4,
329-
ctypes.pointer(shape.c_array), dtype))
359+
ctypes.pointer(shape.c_array), dtype.c_api_value))
330360
elif dtype == af_int64:
331361
safe_call(backend.get().af_constant_long(
332362
ctypes.pointer(out.arr), ctypes.c_longlong(value.real), 4, ctypes.pointer(shape.c_array)))
@@ -335,6 +365,6 @@ def _constant_array(value: int | float | bool | complex, shape: CShape, dtype: D
335365
ctypes.pointer(out.arr), ctypes.c_ulonglong(value.real), 4, ctypes.pointer(shape.c_array)))
336366
else:
337367
safe_call(backend.get().af_constant(
338-
ctypes.pointer(out.arr), ctypes.c_double(value), 4, ctypes.pointer(shape.c_array), dtype))
368+
ctypes.pointer(out.arr), ctypes.c_double(value), 4, ctypes.pointer(shape.c_array), dtype.c_api_value))
339369

340370
return out

arrayfire/array_api/pytest.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[pytest]
2-
addopts = --cache-clear --cov=./arrayfire/array_api --flake8 --isort ./arrayfire/array_api
2+
addopts = --cache-clear --cov=./arrayfire/array_api --flake8 --isort -s ./arrayfire/array_api
33
console_output_style = classic
44
markers = mypy

arrayfire/array_api/tests/test_array_object.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,21 @@ def test_array_from_unsupported_type() -> None:
9393

9494
with pytest.raises(TypeError):
9595
Array({1: 2, 3: 4}) # type: ignore[arg-type]
96+
97+
98+
def test_array_getitem() -> None:
99+
array = Array([1, 2, 3, 4, 5])
100+
101+
int_item = array[2]
102+
assert array.dtype == int_item.dtype
103+
assert int_item.scalar() == 3
104+
105+
# TODO add more tests for different dtypes
106+
107+
108+
# def test_array_sum() -> None: # BUG no element-wise adding
109+
# array = Array([1, 2, 3])
110+
# res = array + 1
111+
# assert res.scalar() == 2
112+
# assert res.scalar() == 3
113+
# assert res.scalar() == 4

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy