
Commit a73b40a

Merge pull request #243 from PerretB/view_type_caster
Casters for strided_views, array_adaptor, and tensor_adaptor
2 parents ff64d8f + fa97cf0
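In short, this merge adds pybind11 type casters so that bound functions returning xtensor strided views and adaptors (in addition to xt::xarray and xt::xtensor) are converted to NumPy arrays. The sketch below illustrates the kind of binding this enables; it is not part of the commit, the names "holder" and "lifetimes_demo" are placeholders, and it assumes the post-merge headers (pyarray.hpp now pulls in pynative_casters.hpp). As in the test bindings added in test_python/main.cpp below, the caller stays responsible for lifetime, hence py::keep_alive.

// Sketch only (not part of the commit): a binding that relies on the new
// xstrided_view caster. "holder" and "lifetimes_demo" are placeholder names.
#include <pybind11/pybind11.h>
#include <xtensor/xarray.hpp>
#include <xtensor/xbuilder.hpp>
#include <xtensor/xstrided_view.hpp>
#include <xtensor-python/pyarray.hpp>   // after this merge, also brings in pynative_casters.hpp

namespace py = pybind11;

struct holder
{
    xt::xarray<double> data = xt::ones<double>({4, 4});

    // Returns a view into "data"; the resulting NumPy array shares that buffer,
    // so the holder must stay alive as long as the array does.
    auto first_row()
    {
        return xt::strided_view(data, {xt::range(0, 1), xt::range(0, 4)});
    }
};

PYBIND11_MODULE(lifetimes_demo, m)
{
    xt::import_numpy();

    py::class_<holder>(m, "holder")
        .def(py::init<>())
        // keep_alive<0, 1>: tie the holder (argument 1 = self) to the returned array (0)
        .def("first_row", &holder::first_row, py::keep_alive<0, 1>());
}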

File tree (7 files changed: +193 -25 lines)

CMakeLists.txt
include/xtensor-python/pyarray.hpp
include/xtensor-python/pynative_casters.hpp
include/xtensor-python/pytensor.hpp
include/xtensor-python/xtensor_type_caster_base.hpp
test_python/main.cpp
test_python/test_pyarray.py

CMakeLists.txt
Lines changed: 10 additions & 9 deletions

@@ -62,15 +62,16 @@ message(STATUS "Found numpy: ${NUMPY_INCLUDE_DIRS}")
 # =====
 
 set(XTENSOR_PYTHON_HEADERS
-    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyarray.hpp
-    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyarray_backstrides.hpp
-    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pycontainer.hpp
-    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pystrides_adaptor.hpp
-    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pytensor.hpp
-    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyvectorize.hpp
-    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_python_config.hpp
-    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_type_caster_base.hpp
-)
+    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyarray.hpp
+    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyarray_backstrides.hpp
+    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pycontainer.hpp
+    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pynative_casters.hpp
+    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pystrides_adaptor.hpp
+    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pytensor.hpp
+    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyvectorize.hpp
+    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_python_config.hpp
+    ${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_type_caster_base.hpp
+)
 
 add_library(xtensor-python INTERFACE)
 target_include_directories(xtensor-python INTERFACE

include/xtensor-python/pyarray.hpp
Lines changed: 1 addition & 5 deletions

@@ -21,6 +21,7 @@
 #include "pyarray_backstrides.hpp"
 #include "pycontainer.hpp"
 #include "pystrides_adaptor.hpp"
+#include "pynative_casters.hpp"
 #include "xtensor_type_caster_base.hpp"
 
 namespace xt
@@ -91,11 +92,6 @@ namespace pybind11
             }
         };
 
-        // Type caster for casting xarray to ndarray
-        template <class T, xt::layout_type L>
-        struct type_caster<xt::xarray<T, L>> : xtensor_type_caster_base<xt::xarray<T, L>>
-        {
-        };
     }
 }

include/xtensor-python/pynative_casters.hpp (new file)
Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
+/***************************************************************************
+* Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay          *
+* Copyright (c) QuantStack                                                 *
+*                                                                          *
+* Distributed under the terms of the BSD 3-Clause License.                 *
+*                                                                          *
+* The full license is in the file LICENSE, distributed with this software. *
+****************************************************************************/
+
+#ifndef PYNATIVE_CASTERS_HPP
+#define PYNATIVE_CASTERS_HPP
+
+#include "xtensor_type_caster_base.hpp"
+
+
+namespace pybind11
+{
+    namespace detail
+    {
+        // Type caster for casting xarray to ndarray
+        template <class T, xt::layout_type L>
+        struct type_caster<xt::xarray<T, L>> : xtensor_type_caster_base<xt::xarray<T, L>>
+        {
+        };
+
+        // Type caster for casting xt::xtensor to ndarray
+        template <class T, std::size_t N, xt::layout_type L>
+        struct type_caster<xt::xtensor<T, N, L>> : xtensor_type_caster_base<xt::xtensor<T, N, L>>
+        {
+        };
+
+        // Type caster for casting xt::xstrided_view to ndarray
+        template <class CT, class S, xt::layout_type L, class FST>
+        struct type_caster<xt::xstrided_view<CT, S, L, FST>> : xtensor_type_caster_base<xt::xstrided_view<CT, S, L, FST>>
+        {
+        };
+
+        // Type caster for casting xt::xarray_adaptor to ndarray
+        template <class EC, xt::layout_type L, class SC, class Tag>
+        struct type_caster<xt::xarray_adaptor<EC, L, SC, Tag>> : xtensor_type_caster_base<xt::xarray_adaptor<EC, L, SC, Tag>>
+        {
+        };
+
+        // Type caster for casting xt::xtensor_adaptor to ndarray
+        template <class EC, std::size_t N, xt::layout_type L, class Tag>
+        struct type_caster<xt::xtensor_adaptor<EC, N, L, Tag>> : xtensor_type_caster_base<xt::xtensor_adaptor<EC, N, L, Tag>>
+        {
+        };
+    }
+}
+
+#endif
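These specializations route each listed expression type through xtensor_type_caster_base, so existing by-value returns keep working unchanged. Below is a small sketch, not from the commit (the function make_identity and the module name casters_demo are illustrative): a bound free function returning an xt::xarray by value needs no lifetime annotation, because ownership of the result passes to the Python side.

// Sketch only (not part of the commit): by-value return through the xarray caster.
#include <cstddef>
#include <pybind11/pybind11.h>
#include <xtensor/xarray.hpp>
#include <xtensor/xbuilder.hpp>
#include <xtensor-python/pyarray.hpp>   // after this merge, also brings in pynative_casters.hpp

namespace py = pybind11;

// Hypothetical helper: build an identity matrix and return it by value.
// Ownership of the result passes to Python, so no keep_alive or
// return_value_policy is needed for this kind of binding.
xt::xarray<double> make_identity(std::size_t n)
{
    xt::xarray<double> out = xt::zeros<double>({n, n});
    for (std::size_t i = 0; i < n; ++i)
    {
        out(i, i) = 1.0;
    }
    return out;
}

PYBIND11_MODULE(casters_demo, m)   // "casters_demo" is a placeholder module name
{
    xt::import_numpy();
    m.def("make_identity", &make_identity);
}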

include/xtensor-python/pytensor.hpp
Lines changed: 1 addition & 5 deletions

@@ -21,6 +21,7 @@
 
 #include "pycontainer.hpp"
 #include "pystrides_adaptor.hpp"
+#include "pynative_casters.hpp"
 #include "xtensor_type_caster_base.hpp"
 
 namespace xt
@@ -99,11 +100,6 @@ namespace pybind11
             }
         };
 
-        // Type caster for casting xt::xtensor to ndarray
-        template <class T, std::size_t N, xt::layout_type L>
-        struct type_caster<xt::xtensor<T, N, L>> : xtensor_type_caster_base<xt::xtensor<T, N, L>>
-        {
-        };
     }
 }

include/xtensor-python/xtensor_type_caster_base.hpp
Lines changed: 6 additions & 6 deletions

@@ -23,7 +23,7 @@ namespace pybind11
 {
     namespace detail
     {
-        // Casts an xtensor (or xarray) type to numpy array.If given a base,
+        // Casts a strided expression type to numpy array.If given a base,
         // the numpy array references the src data, otherwise it'll make a copy.
         // The writeable attributes lets you specify writeable flag for the array.
         template <typename Type>
@@ -39,7 +39,7 @@
             std::vector<std::size_t> python_shape(src.shape().size());
             std::copy(src.shape().begin(), src.shape().end(), python_shape.begin());
 
-            array a(python_shape, python_strides, src.begin(), base);
+            array a(python_shape, python_strides, &*(src.begin()), base);
 
             if (!writeable)
             {
@@ -49,8 +49,8 @@
             return a.release();
         }
 
-        // Takes an lvalue ref to some xtensor (or xarray) type and a (python) base object, creating a numpy array that
-        // reference the xtensor object's data with `base` as the python-registered base class (if omitted,
+        // Takes an lvalue ref to some strided expression type and a (python) base object, creating a numpy array that
+        // reference the expression object's data with `base` as the python-registered base class (if omitted,
         // the base will be set to None, and lifetime management is up to the caller). The numpy array is
        // non-writeable if the given type is const.
         template <typename Type, typename CType>
@@ -59,7 +59,7 @@
             return xtensor_array_cast<Type>(src, parent, !std::is_const<CType>::value);
         }
 
-        // Takes a pointer to xtensor (or xarray), builds a capsule around it, then returns a numpy
+        // Takes a pointer to a strided expression, builds a capsule around it, then returns a numpy
         // array that references the encapsulated data with a python-side reference to the capsule to tie
         // its destruction to that of any dependent python objects. Const-ness is determined by whether or
         // not the CType of the pointer given is const.
@@ -70,7 +70,7 @@
             return xtensor_ref_array<Type>(*src, base);
         }
 
-        // Base class of type_caster for xtensor and xarray
+        // Base class of type_caster for strided expressions
         template <class Type>
         struct xtensor_type_caster_base
         {
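The one functional change above is `&*(src.begin())` in place of `src.begin()`: contiguous containers hand back something pointer-like from begin(), but strided views and adaptors return iterator classes, and py::array needs a raw data pointer. Dereferencing the iterator and taking the element's address yields that pointer in both cases. A standalone sketch of the idea (not from the commit):

// Sketch only: both a contiguous container and a strided view expose the first
// element's address through &*(expr.begin()).
#include <iostream>
#include <xtensor/xarray.hpp>
#include <xtensor/xstrided_view.hpp>

int main()
{
    xt::xarray<double> a = {{1., 2., 3.}, {4., 5., 6.}};
    auto v = xt::strided_view(a, {xt::range(0, 1), xt::range(0, 3, 2)});

    // a.begin() may be (or convert to) a raw pointer, while v.begin() is an
    // iterator class; &*(...) gives a double* to the first element either way.
    double* pa = &*(a.begin());
    double* pv = &*(v.begin());

    // Both start at element (0, 0), so the two pointers should compare equal here.
    std::cout << (pa == pv) << '\n';   // expected: 1
    return 0;
}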

test_python/main.cpp
Lines changed: 56 additions & 0 deletions

@@ -15,6 +15,8 @@
 #include "xtensor-python/pyarray.hpp"
 #include "xtensor-python/pytensor.hpp"
 #include "xtensor-python/pyvectorize.hpp"
+#include "xtensor/xadapt.hpp"
+#include "xtensor/xstrided_view.hpp"
 
 namespace py = pybind11;
 using complex_t = std::complex<double>;
@@ -133,6 +135,49 @@ class C
     array_type m_array;
 };
 
+struct test_native_casters
+{
+    using array_type = xt::xarray<double>;
+    array_type a = xt::ones<double>({50, 50});
+
+    const auto & get_array()
+    {
+        return a;
+    }
+
+    auto get_strided_view()
+    {
+        return xt::strided_view(a, {xt::range(0, 1), xt::range(0, 3, 2)});
+    }
+
+    auto get_array_adapter()
+    {
+        using shape_type = std::vector<size_t>;
+        shape_type shape = {2, 2};
+        shape_type stride = {3, 2};
+        return xt::adapt(a.data(), 4, xt::no_ownership(), shape, stride);
+    }
+
+    auto get_tensor_adapter()
+    {
+        using shape_type = std::array<size_t, 2>;
+        shape_type shape = {2, 2};
+        shape_type stride = {3, 2};
+        return xt::adapt(a.data(), 4, xt::no_ownership(), shape, stride);
+    }
+
+    auto get_owning_array_adapter()
+    {
+        size_t size = 100;
+        int * data = new int[size];
+        std::fill(data, data + size, 1);
+
+        using shape_type = std::vector<size_t>;
+        shape_type shape = {size};
+        return xt::adapt(std::move(data), size, xt::acquire_ownership(), shape);
+    }
+};
+
 xt::pyarray<A> dtype_to_python()
 {
     A a1{123, 321, 'a', {1, 2, 3}};
@@ -257,4 +302,15 @@ PYBIND11_MODULE(xtensor_python_test, m)
 
     m.def("diff_shape_overload", [](xt::pytensor<int, 1> a) { return 1; });
     m.def("diff_shape_overload", [](xt::pytensor<int, 2> a) { return 2; });
+
+    py::class_<test_native_casters>(m, "test_native_casters")
+        .def(py::init<>())
+        .def("get_array", &test_native_casters::get_array, py::return_value_policy::reference_internal)  // memory managed by the class instance
+        .def("get_strided_view", &test_native_casters::get_strided_view, py::keep_alive<0, 1>())  // keep_alive<0, 1>() => do not free "self" before the returned view
+        .def("get_array_adapter", &test_native_casters::get_array_adapter, py::keep_alive<0, 1>())  // keep_alive<0, 1>() => do not free "self" before the returned adapter
+        .def("get_tensor_adapter", &test_native_casters::get_tensor_adapter, py::keep_alive<0, 1>())  // keep_alive<0, 1>() => do not free "self" before the returned adapter
+        .def("get_owning_array_adapter", &test_native_casters::get_owning_array_adapter)  // auto memory management as the adapter owns its memory
+        .def("view_keep_alive_member_function", [](test_native_casters & self, xt::pyarray<double> & a)  // keep_alive<0, 2>() => do not free second parameter before the returned view
+             {return xt::reshape_view(a, {a.size(), });},
+             py::keep_alive<0, 2>());
 }

test_python/test_pyarray.py
Lines changed: 67 additions & 0 deletions

@@ -166,6 +166,73 @@ def test_diff_shape_overload(self):
         # FIXME: the TypeError information is not informative
         xt.diff_shape_overload(np.ones((2, 2, 2)))
 
+    def test_native_casters(self):
+        import gc
+
+        # check keep alive policy for get_strided_view()
+        gc.collect()
+        obj = xt.test_native_casters()
+        a = obj.get_strided_view()
+        obj = None
+        gc.collect()
+        _ = np.zeros((100, 100))
+        self.assertEqual(a.sum(), a.size)
+
+        # check keep alive policy for get_array_adapter()
+        gc.collect()
+        obj = xt.test_native_casters()
+        a = obj.get_array_adapter()
+        obj = None
+        gc.collect()
+        _ = np.zeros((100, 100))
+        self.assertEqual(a.sum(), a.size)
+
+        # check keep alive policy for get_tensor_adapter()
+        gc.collect()
+        obj = xt.test_native_casters()
+        a = obj.get_tensor_adapter()
+        obj = None
+        gc.collect()
+        _ = np.zeros((100, 100))
+        self.assertEqual(a.sum(), a.size)
+
+        # check keep alive policy for get_owning_array_adapter()
+        gc.collect()
+        obj = xt.test_native_casters()
+        a = obj.get_owning_array_adapter()
+        gc.collect()
+        _ = np.zeros((100, 100))
+        self.assertEqual(a.sum(), a.size)
+
+        # check keep alive policy for view_keep_alive_member_function()
+        gc.collect()
+        a = np.ones((100, 100))
+        b = obj.view_keep_alive_member_function(a)
+        obj = None
+        a = None
+        gc.collect()
+        _ = np.zeros((100, 100))
+        self.assertEqual(b.sum(), b.size)
+
+        # check shared buffer (ensure that no copy is done)
+        obj = xt.test_native_casters()
+        arr = obj.get_array()
+
+        strided_view = obj.get_strided_view()
+        strided_view[0, 1] = -1
+        self.assertEqual(strided_view.shape, (1, 2))
+        self.assertEqual(arr[0, 2], -1)
+
+        adapter = obj.get_array_adapter()
+        self.assertEqual(adapter.shape, (2, 2))
+        adapter[1, 1] = -2
+        self.assertEqual(arr[0, 5], -2)
+
+        adapter = obj.get_tensor_adapter()
+        self.assertEqual(adapter.shape, (2, 2))
+        adapter[1, 1] = -3
+        self.assertEqual(arr[0, 5], -3)
+
 
 class AttributeTest(TestCase):
