Skip to content

Commit 769c16c

Browse files
committed
Fix typing in array object. Add tests
1 parent 9c0435a commit 769c16c

File tree

2 files changed

+49
-9
lines changed

2 files changed

+49
-9
lines changed

arrayfire/array_api/_array_object.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,17 @@ class Array:
3939
__array_priority__ = 30
4040

4141
def __init__(
42-
self, x: None | Array | py_array.array | int | ctypes.c_void_p | list = None, dtype: None | Dtype = None,
43-
pointer_source: PointerSource = PointerSource.host, shape: None | ShapeType = None,
44-
offset: None | ctypes._SimpleCData[int] = None, strides: None | ShapeType = None) -> None:
42+
self, x: None | Array | py_array.array | int | ctypes.c_void_p | list = None,
43+
dtype: None | Dtype | str = None, shape: None | ShapeType = None,
44+
pointer_source: PointerSource = PointerSource.host, offset: None | ctypes._SimpleCData[int] = None,
45+
strides: None | ShapeType = None) -> None:
4546
_no_initial_dtype = False # HACK, FIXME
4647

4748
# Initialise array object
4849
self.arr = ctypes.c_void_p(0)
4950

5051
if isinstance(dtype, str):
51-
dtype = _str_to_dtype(dtype)
52+
dtype = _str_to_dtype(dtype) # type: ignore[arg-type]
5253

5354
if dtype is None:
5455
_no_initial_dtype = True

arrayfire/array_api/tests/test_array_object.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
import array as pyarray
2+
13
import pytest
24

35
from arrayfire.array_api import Array, float32, int16
46
from arrayfire.array_api._dtypes import supported_dtypes
57

68
# TODO change separated methods with setup and teardown to avoid code duplication
9+
# TODO add tests for array arguments: device, offset, strides
710

811

9-
def test_empty_array() -> None:
12+
def test_create_empty_array() -> None:
1013
array = Array()
1114

1215
assert array.dtype == float32
@@ -16,7 +19,7 @@ def test_empty_array() -> None:
1619
assert len(array) == 0
1720

1821

19-
def test_empty_array_with_nonempty_dtype() -> None:
22+
def test_create_empty_array_with_nonempty_dtype() -> None:
2023
array = Array(dtype=int16)
2124

2225
assert array.dtype == int16
@@ -26,7 +29,32 @@ def test_empty_array_with_nonempty_dtype() -> None:
2629
assert len(array) == 0
2730

2831

29-
def test_empty_array_with_nonempty_shape() -> None:
32+
def test_create_empty_array_with_str_dtype() -> None:
33+
array = Array(dtype="short int")
34+
35+
assert array.dtype == int16
36+
assert array.ndim == 0
37+
assert array.size == 0
38+
assert array.shape == ()
39+
assert len(array) == 0
40+
41+
42+
def test_create_empty_array_with_literal_dtype() -> None:
43+
array = Array(dtype="h")
44+
45+
assert array.dtype == int16
46+
assert array.ndim == 0
47+
assert array.size == 0
48+
assert array.shape == ()
49+
assert len(array) == 0
50+
51+
52+
def test_create_empty_array_with_not_matching_str_dtype() -> None:
53+
with pytest.raises(TypeError):
54+
Array(dtype="hello world")
55+
56+
57+
def test_create_empty_array_with_nonempty_shape() -> None:
3058
array = Array(shape=(2, 3))
3159

3260
assert array.dtype == float32
@@ -36,7 +64,7 @@ def test_empty_array_with_nonempty_shape() -> None:
3664
assert len(array) == 2
3765

3866

39-
def test_array_from_1d_list() -> None:
67+
def test_create_array_from_1d_list() -> None:
4068
array = Array([1, 2, 3])
4169

4270
assert array.dtype == float32
@@ -46,11 +74,22 @@ def test_array_from_1d_list() -> None:
4674
assert len(array) == 3
4775

4876

49-
def test_array_from_2d_list() -> None:
77+
def test_create_array_from_2d_list() -> None:
5078
with pytest.raises(TypeError):
5179
Array([[1, 2, 3], [1, 2, 3]])
5280

5381

82+
def test_create_array_from_pyarray() -> None:
83+
py_array = pyarray.array("f", [1, 2, 3])
84+
array = Array(py_array)
85+
86+
assert array.dtype == float32
87+
assert array.ndim == 1
88+
assert array.size == 3
89+
assert array.shape == (3,)
90+
assert len(array) == 3
91+
92+
5493
def test_array_from_list_with_unsupported_dtype() -> None:
5594
for dtype in supported_dtypes:
5695
if dtype == float32:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy