Skip to content

Commit 1b97e58

Browse files
committed
dump
1 parent e38ce34 commit 1b97e58

File tree

3 files changed

+61
-0
lines changed

3 files changed

+61
-0
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from . import xp
2+
3+
4+
def test_array_namespace_info():
5+
assert hasattr(xp, "__array_namespace_info__")
6+
# TODO: test output

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,16 @@ def test_ceil(x):
933933
unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True)
934934

935935

936+
@pytest.mark.min_version("2023.12")
937+
@given(hh.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
938+
def test_clip(x):
939+
# TODO: test min/max kwargs, adjust values testing accordingly
940+
out = xp.clip(x)
941+
ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype)
942+
ph.assert_shape("clip", out_shape=out.shape, expected=x.shape)
943+
ph.assert_array_elements("clip", out=out, expected=x)
944+
945+
936946
if api_version >= "2022.12":
937947

938948
@given(hh.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
@@ -943,6 +953,15 @@ def test_conj(x):
943953
unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate"))
944954

945955

956+
@pytest.mark.min_version("2023.12")
957+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
958+
def test_copysign(x1, x2):
959+
out = xp.copysign(x1, x2)
960+
ph.assert_dtype("copysign", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
961+
ph.assert_result_shape("copysign", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
962+
# TODO: values testing
963+
964+
946965
@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
947966
def test_cos(x):
948967
out = xp.cos(x)
@@ -1095,6 +1114,15 @@ def test_greater_equal(ctx, data):
10951114
)
10961115

10971116

1117+
@pytest.mark.min_version("2023.12")
1118+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
1119+
def test_hypot(x1, x2):
1120+
out = xp.hypot(x1, x2)
1121+
ph.assert_dtype("hypot", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1122+
ph.assert_result_shape("hypot", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1123+
binary_assert_against_refimpl("hypot", x1, x2, out, math.hypot)
1124+
1125+
10981126
if api_version >= "2022.12":
10991127

11001128
@given(hh.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
@@ -1261,6 +1289,24 @@ def test_logical_xor(x1, x2):
12611289
)
12621290

12631291

1292+
@pytest.mark.min_version("2023.12")
1293+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
1294+
def test_maximum(x1, x2):
1295+
out = xp.maximum(x1, x2)
1296+
ph.assert_dtype("maximum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1297+
ph.assert_result_shape("maximum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1298+
binary_assert_against_refimpl("maximum", x1, x2, out, max, strict_check=True)
1299+
1300+
1301+
@pytest.mark.min_version("2023.12")
1302+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
1303+
def test_minimum(x1, x2):
1304+
out = xp.minimum(x1, x2)
1305+
ph.assert_dtype("minimum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1306+
ph.assert_result_shape("minimum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1307+
binary_assert_against_refimpl("minimum", x1, x2, out, min, strict_check=True)
1308+
1309+
12641310
@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
12651311
@given(data=st.data())
12661312
def test_multiply(ctx, data):

array_api_tests/test_statistical_functions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@
1616
from .typing import DataType
1717

1818

19+
@given(hh.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_dims=1, max_dims=1)))
20+
def test_cumulative_sum(x):
21+
# TODO: test kwargs + diff shapes, adjust shape and values testing accordingly
22+
out = xp.cumulative_sum(x)
23+
# TODO: assert dtype
24+
ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=x.shape)
25+
# TODO: assert values
26+
27+
1928
def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
2029
dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
2130
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy