Skip to content

Commit b14aa91

Browse files
committed
Replace dim4 with CShape
1 parent 75c1d43 commit b14aa91

File tree

5 files changed

+83
-71
lines changed

5 files changed

+83
-71
lines changed

arrayfire/array_api/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,4 @@
66
"complex64", "complex128", "bool"]
77

88
from ._array_object import Array
9-
from ._dtypes import (
10-
bool, complex64, complex128, float32, float64, int16, int32, int64, uint8, uint16, uint32, uint64)
9+
from ._dtypes import bool, complex64, complex128, float32, float64, int16, int32, int64, uint8, uint16, uint32, uint64

arrayfire/array_api/_array_object.py

Lines changed: 25 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22

33
import array as py_array
44
import ctypes
5-
import math
65
from dataclasses import dataclass
76

87
from arrayfire import backend, safe_call # TODO refactoring
98
from arrayfire.array import _in_display_dims_limit # TODO refactoring
109

11-
from ._dtypes import Dtype, c_dim_t, float32, supported_dtypes
10+
from ._dtypes import CShape, Dtype, c_dim_t, float32, supported_dtypes
1211
from ._utils import Device, PointerSource, to_str
1312

1413
ShapeType = tuple[None | int, ...]
@@ -28,7 +27,6 @@ class Array:
2827
__array_priority__ = 30
2928

3029
# Initialisation
31-
_array_buffer = _ArrayBuffer()
3230
arr = ctypes.c_void_p(0)
3331

3432
def __init__(
@@ -46,12 +44,12 @@ def __init__(
4644
if x is None:
4745
if not shape: # shape is None or empty tuple
4846
safe_call(backend.get().af_create_handle(
49-
ctypes.pointer(self.arr), 0, ctypes.pointer(dim4()), dtype.c_api_value))
47+
ctypes.pointer(self.arr), 0, ctypes.pointer(CShape().c_array), dtype.c_api_value))
5048
return
5149

5250
# NOTE: applies inplace changes for self.arr
5351
safe_call(backend.get().af_create_handle(
54-
ctypes.pointer(self.arr), len(shape), ctypes.pointer(dim4(*shape)), dtype.c_api_value))
52+
ctypes.pointer(self.arr), len(shape), ctypes.pointer(CShape(*shape).c_array), dtype.c_api_value))
5553
return
5654

5755
if isinstance(x, Array):
@@ -61,19 +59,16 @@ def __init__(
6159
if isinstance(x, py_array.array):
6260
_type_char = x.typecode
6361
_array_buffer = _ArrayBuffer(*x.buffer_info())
64-
numdims, idims = _get_info(shape, _array_buffer.length)
6562

6663
elif isinstance(x, list):
6764
_array = py_array.array("f", x) # BUG [True, False] -> dtype: f32 # TODO add int and float
6865
_type_char = _array.typecode
6966
_array_buffer = _ArrayBuffer(*_array.buffer_info())
70-
numdims, idims = _get_info(shape, _array_buffer.length)
7167

7268
elif isinstance(x, int) or isinstance(x, ctypes.c_void_p): # TODO
7369
_array_buffer = _ArrayBuffer(x if not isinstance(x, ctypes.c_void_p) else x.value)
74-
numdims, idims = _get_info(shape, _array_buffer.length)
7570

76-
if not math.prod(idims):
71+
if not shape:
7772
raise RuntimeError("Expected to receive the initial shape due to the x being a data pointer.")
7873

7974
if _no_initial_dtype:
@@ -84,34 +79,37 @@ def __init__(
8479
else:
8580
raise TypeError("Passed object x is an object of unsupported class.")
8681

82+
_cshape = _get_cshape(shape, _array_buffer.length)
83+
8784
if not _no_initial_dtype and dtype.typecode != _type_char:
8885
raise TypeError("Can not create array of requested type from input data type")
8986

9087
if not (offset or strides):
9188
if pointer_source == PointerSource.host:
9289
safe_call(backend.get().af_create_array(
93-
ctypes.pointer(self.arr), ctypes.c_void_p(_array_buffer.address), numdims,
94-
ctypes.pointer(dim4(*idims)), dtype.c_api_value))
90+
ctypes.pointer(self.arr), ctypes.c_void_p(_array_buffer.address), _cshape.original_shape,
91+
ctypes.pointer(_cshape.c_array), dtype.c_api_value))
9592
return
9693

9794
safe_call(backend.get().af_device_array(
98-
ctypes.pointer(self.arr), ctypes.c_void_p(_array_buffer.address), numdims,
99-
ctypes.pointer(dim4(*idims)), dtype.c_api_value))
95+
ctypes.pointer(self.arr), ctypes.c_void_p(_array_buffer.address), _cshape.original_shape,
96+
ctypes.pointer(_cshape.c_array), dtype.c_api_value))
10097
return
10198

102-
if offset is None: # TODO
99+
if offset is None:
103100
offset = c_dim_t(0)
104101

105-
if strides is None: # TODO
106-
strides = (1, idims[0], idims[0]*idims[1], idims[0]*idims[1]*idims[2])
102+
if strides is None:
103+
strides = (1, _cshape[0], _cshape[0]*_cshape[1], _cshape[0]*_cshape[1]*_cshape[2])
107104

108105
if len(strides) < 4:
109106
strides += (strides[-1], ) * (4 - len(strides))
110-
strides_dim4 = dim4(*strides)
107+
strides_cshape = CShape(*strides).c_array
111108

112109
safe_call(backend.get().af_create_strided_array(
113-
ctypes.pointer(self.arr), ctypes.c_void_p(_array_buffer.address), offset, numdims,
114-
ctypes.pointer(dim4(*idims)), ctypes.pointer(strides_dim4), dtype.c_api_value, pointer_source.value))
110+
ctypes.pointer(self.arr), ctypes.c_void_p(_array_buffer.address), offset, _cshape.original_shape,
111+
ctypes.pointer(_cshape.c_array), ctypes.pointer(strides_cshape), dtype.c_api_value,
112+
pointer_source.value))
115113

116114
def __str__(self) -> str: # FIXME
117115
if not _in_display_dims_limit(self.shape):
@@ -126,7 +124,7 @@ def __len__(self) -> int:
126124
return self.shape[0] if self.shape else 0 # type: ignore[return-value]
127125

128126
def __pos__(self) -> Array:
129-
"""y
127+
"""
130128
Return +self
131129
"""
132130
return self
@@ -190,8 +188,7 @@ def shape(self) -> ShapeType:
190188
d3 = c_dim_t(0)
191189
safe_call(backend.get().af_get_dims(
192190
ctypes.pointer(d0), ctypes.pointer(d1), ctypes.pointer(d2), ctypes.pointer(d3), self.arr))
193-
dims = (d0.value, d1.value, d2.value, d3.value)
194-
return dims[:self.ndim] # FIXME An array dimension must be None if and only if a dimension is unknown
191+
return (d0.value, d1.value, d2.value, d3.value)[:self.ndim] # Skip passing None values
195192

196193
def _as_str(self) -> str:
197194
arr_str = ctypes.c_char_p(0)
@@ -201,30 +198,6 @@ def _as_str(self) -> str:
201198
safe_call(backend.get().af_free_host(arr_str))
202199
return py_str
203200

204-
# def _get_metadata_str(self, show_dims: bool = True) -> str:
205-
# return (
206-
# "arrayfire.Array()\n"
207-
# f"Type: {self.dtype.typename}\n"
208-
# f"Dims: {str(self._dims) if show_dims else ''}")
209-
210-
# @property
211-
# def dtype(self) -> ...:
212-
# dty = ctypes.c_int()
213-
# safe_call(backend.get().af_get_type(ctypes.pointer(dty), self.arr)) # -> new dty
214-
215-
# @safe_call
216-
# def backend()
217-
# ...
218-
219-
# @backend(safe=True)
220-
# def af_get_type(arr) -> ...:
221-
# dty = ctypes.c_int()
222-
# safe_call(backend.get().af_get_type(ctypes.pointer(dty), self.arr)) # -> new dty
223-
# return dty
224-
225-
# def new_dtype():
226-
# return af_get_type(self.arr)
227-
228201

229202
def _metadata_string(dtype: Dtype, dims: None | ShapeType = None) -> str:
230203
return (
@@ -233,20 +206,14 @@ def _metadata_string(dtype: Dtype, dims: None | ShapeType = None) -> str:
233206
f"Dims: {str(dims) if dims else ''}")
234207

235208

236-
def _get_info(shape: None | tuple[int], buffer_length: int) -> tuple[int, list[int]]:
237-
# TODO refactor
209+
def _get_cshape(shape: None | tuple[int], buffer_length: int) -> CShape:
238210
if shape:
239-
numdims = len(shape)
240-
idims = [1]*4
241-
for i in range(numdims):
242-
idims[i] = shape[i]
243-
elif (buffer_length != 0):
244-
idims = [buffer_length, 1, 1, 1]
245-
numdims = 1
246-
else:
247-
raise RuntimeError("Invalid size")
211+
return CShape(*shape)
212+
213+
if buffer_length != 0:
214+
return CShape(buffer_length)
248215

249-
return numdims, idims
216+
raise RuntimeError("Shape and buffer length are size invalid.")
250217

251218

252219
def _c_api_value_to_dtype(value: int) -> Dtype:
@@ -282,16 +249,6 @@ def _str_to_dtype(value: int) -> Dtype:
282249
# return out
283250

284251

285-
def dim4(d0: int = 1, d1: int = 1, d2: int = 1, d3: int = 1): # type: ignore # FIXME
286-
c_dim4 = c_dim_t * 4 # ctypes.c_int | ctypes.c_longlong * 4
287-
out = c_dim4(1, 1, 1, 1)
288-
289-
for i, dim in enumerate((d0, d1, d2, d3)):
290-
if dim is not None:
291-
out[i] = c_dim_t(dim)
292-
293-
return out
294-
295252
# TODO replace candidate below
296253
# def dim4_to_tuple(shape: ShapeType, default: int=1) -> ShapeType:
297254
# assert(isinstance(dims, tuple))

arrayfire/array_api/_dtypes.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import ctypes
24
from dataclasses import dataclass
35
from typing import Type
@@ -31,6 +33,39 @@ class Dtype:
3133
bool = Dtype("b", ctypes.c_bool, "bool", 4)
3234

3335
supported_dtypes = [
34-
# int8,
3536
int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128, bool
3637
]
38+
39+
40+
class CShape(tuple):
41+
def __new__(cls, *args: int) -> CShape:
42+
cls.original_shape = len(args)
43+
return tuple.__new__(cls, args)
44+
45+
def __init__(self, x1: int = 1, x2: int = 1, x3: int = 1, x4: int = 1) -> None:
46+
self.x1 = x1
47+
self.x2 = x2
48+
self.x3 = x3
49+
self.x4 = x4
50+
51+
def __repr__(self) -> str:
52+
return f"{self.__class__.__name__}{self.x1, self.x2, self.x3, self.x4}"
53+
54+
@property
55+
def c_array(self): # type: ignore[no-untyped-def]
56+
c_shape = c_dim_t * 4 # ctypes.c_int | ctypes.c_longlong * 4
57+
return c_shape(c_dim_t(self.x1), c_dim_t(self.x2), c_dim_t(self.x3), c_dim_t(self.x4))
58+
59+
60+
# @safe_call
61+
# def backend()
62+
# ...
63+
64+
# @backend(safe=True)
65+
# def af_get_type(arr) -> ...:
66+
# dty = ctypes.c_int()
67+
# safe_call(backend.get().af_get_type(ctypes.pointer(dty), self.arr)) # -> new dty
68+
# return dty
69+
70+
# def new_dtype():
71+
# return af_get_type(self.arr)

arrayfire/array_api/pytest.ini

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[pytest]
2+
addopts = --cache-clear --cov=./arrayfire/array_api --flake8 --mypy --isort ./arrayfire/array_api
3+
console_output_style = classic
4+
markers = mypy

arrayfire/array_api/tests/test_array.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pytest
2+
13
from arrayfire.array_api import Array, float32
24

35

@@ -9,3 +11,18 @@ def test_empty_array() -> None:
911
assert array.size == 0
1012
assert array.shape == ()
1113
assert len(array) == 0
14+
15+
16+
def test_array_from_1d_list() -> None:
17+
array = Array([1, 2, 3])
18+
19+
assert array.dtype == float32
20+
assert array.ndim == 1
21+
assert array.size == 3
22+
assert array.shape == (3,)
23+
assert len(array) == 3
24+
25+
26+
def test_array_from_2d_list() -> None:
27+
with pytest.raises(TypeError):
28+
Array([[1, 2, 3], [1, 2, 3]])

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy