1
+ import array as pyarray
2
+
1
3
import pytest
2
4
3
5
from arrayfire .array_api import Array , float32 , int16
4
6
from arrayfire .array_api ._dtypes import supported_dtypes
5
7
6
8
# TODO change separated methods with setup and teardown to avoid code duplication
9
+ # TODO add tests for array arguments: device, offset, strides
7
10
8
11
9
- def test_empty_array () -> None :
12
+ def test_create_empty_array () -> None :
10
13
array = Array ()
11
14
12
15
assert array .dtype == float32
@@ -16,7 +19,7 @@ def test_empty_array() -> None:
16
19
assert len (array ) == 0
17
20
18
21
19
- def test_empty_array_with_nonempty_dtype () -> None :
22
+ def test_create_empty_array_with_nonempty_dtype () -> None :
20
23
array = Array (dtype = int16 )
21
24
22
25
assert array .dtype == int16
@@ -26,7 +29,32 @@ def test_empty_array_with_nonempty_dtype() -> None:
26
29
assert len (array ) == 0
27
30
28
31
29
- def test_empty_array_with_nonempty_shape () -> None :
32
+ def test_create_empty_array_with_str_dtype () -> None :
33
+ array = Array (dtype = "short int" )
34
+
35
+ assert array .dtype == int16
36
+ assert array .ndim == 0
37
+ assert array .size == 0
38
+ assert array .shape == ()
39
+ assert len (array ) == 0
40
+
41
+
42
+ def test_create_empty_array_with_literal_dtype () -> None :
43
+ array = Array (dtype = "h" )
44
+
45
+ assert array .dtype == int16
46
+ assert array .ndim == 0
47
+ assert array .size == 0
48
+ assert array .shape == ()
49
+ assert len (array ) == 0
50
+
51
+
52
+ def test_create_empty_array_with_not_matching_str_dtype () -> None :
53
+ with pytest .raises (TypeError ):
54
+ Array (dtype = "hello world" )
55
+
56
+
57
+ def test_create_empty_array_with_nonempty_shape () -> None :
30
58
array = Array (shape = (2 , 3 ))
31
59
32
60
assert array .dtype == float32
@@ -36,7 +64,7 @@ def test_empty_array_with_nonempty_shape() -> None:
36
64
assert len (array ) == 2
37
65
38
66
39
- def test_array_from_1d_list () -> None :
67
+ def test_create_array_from_1d_list () -> None :
40
68
array = Array ([1 , 2 , 3 ])
41
69
42
70
assert array .dtype == float32
@@ -46,11 +74,22 @@ def test_array_from_1d_list() -> None:
46
74
assert len (array ) == 3
47
75
48
76
49
- def test_array_from_2d_list () -> None :
77
+ def test_create_array_from_2d_list () -> None :
50
78
with pytest .raises (TypeError ):
51
79
Array ([[1 , 2 , 3 ], [1 , 2 , 3 ]])
52
80
53
81
82
+ def test_create_array_from_pyarray () -> None :
83
+ py_array = pyarray .array ("f" , [1 , 2 , 3 ])
84
+ array = Array (py_array )
85
+
86
+ assert array .dtype == float32
87
+ assert array .ndim == 1
88
+ assert array .size == 3
89
+ assert array .shape == (3 ,)
90
+ assert len (array ) == 3
91
+
92
+
54
93
def test_array_from_list_with_unsupported_dtype () -> None :
55
94
for dtype in supported_dtypes :
56
95
if dtype == float32 :
0 commit comments