Skip to content

Commit 4187b27

Browse files
committed
Add tests for arithmetic operators
1 parent 5a29ffa commit 4187b27

File tree

2 files changed

+111
-6
lines changed

2 files changed

+111
-6
lines changed

arrayfire/array_api/_array_object.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,35 +159,35 @@ def __add__(self, other: int | float | Array, /) -> Array:
159159
"""
160160
return _process_c_function(self, other, backend.get().af_add)
161161

162-
def __sub__(self, other: int | float | bool | complex | Array, /) -> Array:
162+
def __sub__(self, other: int | float | Array, /) -> Array:
163163
"""
164164
Return self - other.
165165
"""
166166
return _process_c_function(self, other, backend.get().af_sub)
167167

168-
def __mul__(self, other: int | float | bool | complex | Array, /) -> Array:
168+
def __mul__(self, other: int | float | Array, /) -> Array:
169169
"""
170170
Return self * other.
171171
"""
172172
return _process_c_function(self, other, backend.get().af_mul)
173173

174-
def __truediv__(self, other: int | float | bool | complex | Array, /) -> Array:
174+
def __truediv__(self, other: int | float | Array, /) -> Array:
175175
"""
176176
Return self / other.
177177
"""
178178
return _process_c_function(self, other, backend.get().af_div)
179179

180-
def __floordiv__(self, other: int | float | bool | complex | Array, /) -> Array:
180+
def __floordiv__(self, other: int | float | Array, /) -> Array:
181181
# TODO
182182
return NotImplemented
183183

184-
def __mod__(self, other: int | float | bool | complex | Array, /) -> Array:
184+
def __mod__(self, other: int | float | Array, /) -> Array:
185185
"""
186186
Return self % other.
187187
"""
188188
return _process_c_function(self, other, backend.get().af_mod)
189189

190-
def __pow__(self, other: int | float | bool | complex | Array, /) -> Array:
190+
def __pow__(self, other: int | float | Array, /) -> Array:
191191
"""
192192
Return self ** other.
193193
"""

arrayfire/array_api/tests/test_array_object.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,108 @@ def test_array_sum() -> None:
111111
assert res[0].scalar() == 2
112112
assert res[1].scalar() == 3
113113
assert res[2].scalar() == 4
114+
115+
res = array + 1.5
116+
assert res[0].scalar() == 2.5
117+
assert res[1].scalar() == 3.5
118+
assert res[2].scalar() == 4.5
119+
120+
res = array + Array([9, 9, 9])
121+
assert res[0].scalar() == 10
122+
assert res[1].scalar() == 11
123+
assert res[2].scalar() == 12
124+
125+
126+
def test_array_sub() -> None:
127+
array = Array([1, 2, 3])
128+
res = array - 1
129+
assert res[0].scalar() == 0
130+
assert res[1].scalar() == 1
131+
assert res[2].scalar() == 2
132+
133+
res = array - 1.5
134+
assert res[0].scalar() == -0.5
135+
assert res[1].scalar() == 0.5
136+
assert res[2].scalar() == 1.5
137+
138+
res = array - Array([9, 9, 9])
139+
assert res[0].scalar() == -8
140+
assert res[1].scalar() == -7
141+
assert res[2].scalar() == -6
142+
143+
144+
def test_array_mul() -> None:
145+
array = Array([1, 2, 3])
146+
res = array * 2
147+
assert res[0].scalar() == 2
148+
assert res[1].scalar() == 4
149+
assert res[2].scalar() == 6
150+
151+
res = array * 1.5
152+
assert res[0].scalar() == 1.5
153+
assert res[1].scalar() == 3
154+
assert res[2].scalar() == 4.5
155+
156+
res = array * Array([9, 9, 9])
157+
assert res[0].scalar() == 9
158+
assert res[1].scalar() == 18
159+
assert res[2].scalar() == 27
160+
161+
162+
def test_array_truediv() -> None:
163+
array = Array([1, 2, 3])
164+
res = array / 2
165+
assert res[0].scalar() == 0.5
166+
assert res[1].scalar() == 1
167+
assert res[2].scalar() == 1.5
168+
169+
res = array / 1.5
170+
assert round(res[0].scalar(), 5) == 0.66667 # type: ignore[arg-type]
171+
assert round(res[1].scalar(), 5) == 1.33333 # type: ignore[arg-type]
172+
assert res[2].scalar() == 2
173+
174+
res = array / Array([2, 2, 2])
175+
assert res[0].scalar() == 0.5
176+
assert res[1].scalar() == 1
177+
assert res[2].scalar() == 1.5
178+
179+
180+
def test_array_floordiv() -> None:
181+
# TODO add test after implementation of __floordiv__
182+
pass
183+
184+
185+
def test_array_mod() -> None:
186+
array = Array([1, 2, 3])
187+
res = array % 2
188+
assert res[0].scalar() == 1
189+
assert res[1].scalar() == 0
190+
assert res[2].scalar() == 1
191+
192+
res = array % 1.5
193+
assert res[0].scalar() == 1.0
194+
assert res[1].scalar() == 0.5
195+
assert res[2].scalar() == 0.0
196+
197+
res = array % Array([9, 9, 9])
198+
assert res[0].scalar() == 1.0
199+
assert res[1].scalar() == 2.0
200+
assert res[2].scalar() == 3.0
201+
202+
203+
def test_array_pow() -> None:
204+
array = Array([1, 2, 3])
205+
res = array ** 2
206+
assert res[0].scalar() == 1
207+
assert res[1].scalar() == 4
208+
assert res[2].scalar() == 9
209+
210+
res = array ** 1.5
211+
assert res[0].scalar() == 1
212+
assert round(res[1].scalar(), 5) == 2.82843 # type: ignore[arg-type]
213+
assert round(res[2].scalar(), 5) == 5.19615 # type: ignore[arg-type]
214+
215+
res = array ** Array([9, 9, 9])
216+
assert res[0].scalar() == 1
217+
assert res[1].scalar() == 512
218+
assert res[2].scalar() == 19683

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy