Skip to content

Commit 1e922c5

Browse files
authored
Merge pull request #251 from adriendelsalle/pyarray-init-list
`pyarray` initializers lists work with all layouts
2 parents 062c8c2 + 66b81ae commit 1e922c5

File tree

2 files changed

+53
-16
lines changed

2 files changed

+53
-16
lines changed

include/xtensor-python/pyarray.hpp

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ namespace xt
224224
storage_type& storage_impl() noexcept;
225225
const storage_type& storage_impl() const noexcept;
226226

227+
layout_type default_dynamic_layout();
228+
227229
friend class xcontainer<pyarray<T, L>>;
228230
friend class pycontainer<pyarray<T, L>>;
229231
};
@@ -254,48 +256,48 @@ namespace xt
254256
inline pyarray<T, L>::pyarray(const value_type& t)
255257
: base_type()
256258
{
257-
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
259+
base_type::resize(xt::shape<shape_type>(t), default_dynamic_layout());
258260
nested_copy(m_storage.begin(), t);
259261
}
260262

261263
template <class T, layout_type L>
262264
inline pyarray<T, L>::pyarray(nested_initializer_list_t<T, 1> t)
263265
: base_type()
264266
{
265-
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
266-
nested_copy(m_storage.begin(), t);
267+
base_type::resize(xt::shape<shape_type>(t), default_dynamic_layout());
268+
L == layout_type::row_major ? nested_copy(m_storage.begin(), t) : nested_copy(this->template begin<layout_type::row_major>(), t);
267269
}
268270

269271
template <class T, layout_type L>
270272
inline pyarray<T, L>::pyarray(nested_initializer_list_t<T, 2> t)
271273
: base_type()
272274
{
273-
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
274-
nested_copy(m_storage.begin(), t);
275+
base_type::resize(xt::shape<shape_type>(t), default_dynamic_layout());
276+
L == layout_type::row_major ? nested_copy(m_storage.begin(), t) : nested_copy(this->template begin<layout_type::row_major>(), t);
275277
}
276278

277279
template <class T, layout_type L>
278280
inline pyarray<T, L>::pyarray(nested_initializer_list_t<T, 3> t)
279281
: base_type()
280282
{
281-
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
282-
nested_copy(m_storage.begin(), t);
283+
base_type::resize(xt::shape<shape_type>(t), default_dynamic_layout());
284+
L == layout_type::row_major ? nested_copy(m_storage.begin(), t) : nested_copy(this->template begin<layout_type::row_major>(), t);
283285
}
284286

285287
template <class T, layout_type L>
286288
inline pyarray<T, L>::pyarray(nested_initializer_list_t<T, 4> t)
287289
: base_type()
288290
{
289-
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
290-
nested_copy(m_storage.begin(), t);
291+
base_type::resize(xt::shape<shape_type>(t), default_dynamic_layout());
292+
L == layout_type::row_major ? nested_copy(m_storage.begin(), t) : nested_copy(this->template begin<layout_type::row_major>(), t);
291293
}
292294

293295
template <class T, layout_type L>
294296
inline pyarray<T, L>::pyarray(nested_initializer_list_t<T, 5> t)
295297
: base_type()
296298
{
297-
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
298-
nested_copy(m_storage.begin(), t);
299+
base_type::resize(xt::shape<shape_type>(t), default_dynamic_layout());
300+
L == layout_type::row_major ? nested_copy(m_storage.begin(), t) : nested_copy(this->template begin<layout_type::row_major>(), t);
299301
}
300302

301303
template <class T, layout_type L>
@@ -443,7 +445,9 @@ namespace xt
443445
// TODO: prevent intermediary shape allocation
444446
shape_type shape = xtl::forward_sequence<shape_type, decltype(e.derived_cast().shape())>(e.derived_cast().shape());
445447
strides_type strides = xtl::make_sequence<strides_type>(shape.size(), size_type(0));
446-
compute_strides(shape, L, strides);
448+
layout_type layout = default_dynamic_layout();
449+
450+
compute_strides(shape, layout, strides);
447451
init_array(shape, strides);
448452
semantic_base::assign(e);
449453
}
@@ -559,6 +563,12 @@ namespace xt
559563
{
560564
return m_storage;
561565
}
566+
567+
template <class T, layout_type L>
568+
layout_type pyarray<T, L>::default_dynamic_layout()
569+
{
570+
return L == layout_type::dynamic ? layout_type::row_major : L;
571+
}
562572
}
563573

564574
#endif

test/test_pyarray.cpp

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,44 @@ namespace xt
3737

3838
TEST(pyarray, initializer_constructor)
3939
{
40-
pyarray<int> t
40+
pyarray<int> r
4141
{{{ 0, 1, 2},
4242
{ 3, 4, 5},
4343
{ 6, 7, 8}},
4444
{{ 9, 10, 11},
4545
{12, 13, 14},
4646
{15, 16, 17}}};
4747

48-
EXPECT_EQ(t.dimension(), 3);
49-
EXPECT_EQ(t(0, 0, 1), 1);
50-
EXPECT_EQ(t.shape()[0], 2);
48+
EXPECT_EQ(r.layout(), xt::layout_type::row_major);
49+
EXPECT_EQ(r.dimension(), 3);
50+
EXPECT_EQ(r(0, 0, 1), 1);
51+
EXPECT_EQ(r.shape()[0], 2);
52+
53+
pyarray<int, xt::layout_type::column_major> c
54+
{{{ 0, 1, 2},
55+
{ 3, 4, 5},
56+
{ 6, 7, 8}},
57+
{{ 9, 10, 11},
58+
{12, 13, 14},
59+
{15, 16, 17}}};
60+
61+
EXPECT_EQ(c.layout(), xt::layout_type::column_major);
62+
EXPECT_EQ(c.dimension(), 3);
63+
EXPECT_EQ(c(0, 0, 1), 1);
64+
EXPECT_EQ(c.shape()[0], 2);
65+
66+
pyarray<int, xt::layout_type::dynamic> d
67+
{{{ 0, 1, 2},
68+
{ 3, 4, 5},
69+
{ 6, 7, 8}},
70+
{{ 9, 10, 11},
71+
{12, 13, 14},
72+
{15, 16, 17}}};
73+
74+
EXPECT_EQ(d.layout(), xt::layout_type::row_major);
75+
EXPECT_EQ(d.dimension(), 3);
76+
EXPECT_EQ(d(0, 0, 1), 1);
77+
EXPECT_EQ(d.shape()[0], 2);
5178
}
5279

5380
TEST(pyarray, shaped_constructor)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy