Skip to content

Commit ab37736

Browse files
committed
add get_2darray for templated matrices and spy plot
1 parent 4ba2bf4 commit ab37736

File tree

3 files changed

+97
-4
lines changed

3 files changed

+97
-4
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ eigen_include = -I /usr/local/include/eigen3
2020
example_execs = minimal modern basic animation nonblock xkcd quiver bar surface subplot fill_inbetween fill update
2121

2222
# Executable names for examples using Eigen
23-
eigen_execs = eigen loglog semilogx semilogy small
23+
eigen_execs = eigen loglog semilogx semilogy small spy
2424

2525
# Example targets (default if just 'make' is called)
2626
examples: $(example_execs)

examples/spy.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#include <vector>
2+
#include <Eigen/Dense>
3+
#include "../matplotlibcpp.h"
4+
namespace plt = matplotlibcpp;
5+
6+
int main() {
7+
8+
const unsigned n = 100;
9+
Eigen::MatrixXd A(n / 2, n);
10+
std::vector<std::vector<double>> B;
11+
12+
for (unsigned i = 0; i < n / 2; ++i) {
13+
A(i, i) = 1;
14+
std::vector<double> row(n);
15+
row[i] = 1;
16+
17+
if (i < n / 2) {
18+
A(i, i + n / 2) = 1;
19+
row[i + n / 2] = 1;
20+
}
21+
B.push_back(row);
22+
}
23+
24+
for (unsigned i = 0; i < n / 2; ++i) {
25+
for (unsigned j = 0; j < n; ++j) {
26+
if (A(i, j) != B[i][j]) {
27+
std::cout << i << "," << j << " differ!\n";
28+
}
29+
}
30+
}
31+
32+
plt::figure();
33+
plt::title("Eigen");
34+
plt::spy(A);
35+
36+
plt::figure();
37+
plt::title("vector");
38+
plt::spy(B);
39+
plt::show();
40+
return 0;
41+
}

matplotlibcpp.h

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ struct _interpreter {
5959
PyObject *s_python_function_fill_between;
6060
PyObject *s_python_function_hist;
6161
PyObject *s_python_function_scatter;
62+
PyObject *s_python_function_spy;
6263
PyObject *s_python_function_subplot;
6364
PyObject *s_python_function_legend;
6465
PyObject *s_python_function_xlim;
@@ -188,6 +189,7 @@ struct _interpreter {
188189
PyObject_GetAttrString(pymod, "fill_between");
189190
s_python_function_hist = PyObject_GetAttrString(pymod, "hist");
190191
s_python_function_scatter = PyObject_GetAttrString(pymod, "scatter");
192+
s_python_function_spy = PyObject_GetAttrString(pymod, "spy");
191193
s_python_function_subplot = PyObject_GetAttrString(pymod, "subplot");
192194
s_python_function_legend = PyObject_GetAttrString(pymod, "legend");
193195
s_python_function_ylim = PyObject_GetAttrString(pymod, "ylim");
@@ -236,7 +238,8 @@ struct _interpreter {
236238
!s_python_function_errorbar || !s_python_function_tight_layout ||
237239
!s_python_function_stem || !s_python_function_xkcd ||
238240
!s_python_function_text || !s_python_function_suptitle ||
239-
!s_python_function_bar || !s_python_function_subplots_adjust) {
241+
!s_python_function_bar || !s_python_function_subplots_adjust ||
242+
!s_python_function_spy) {
240243
throw std::runtime_error("Couldn't find required function!");
241244
}
242245

@@ -253,6 +256,7 @@ struct _interpreter {
253256
!PyFunction_Check(s_python_function_loglog) ||
254257
!PyFunction_Check(s_python_function_fill) ||
255258
!PyFunction_Check(s_python_function_fill_between) ||
259+
!PyFunction_Check(s_python_function_spy) ||
256260
!PyFunction_Check(s_python_function_subplot) ||
257261
!PyFunction_Check(s_python_function_legend) ||
258262
!PyFunction_Check(s_python_function_annotate) ||
@@ -396,7 +400,7 @@ PyObject *get_2darray(const std::vector<::std::vector<Numeric>> &v) {
396400
detail::_interpreter::get(); // interpreter needs to be initialized for the
397401
// numpy commands to work
398402
if (v.size() < 1)
399-
throw std::runtime_error("get_2d_array v too small");
403+
throw std::runtime_error("get_2darray v too small");
400404

401405
npy_intp vsize[2] = {static_cast<npy_intp>(v.size()),
402406
static_cast<npy_intp>(v[0].size())};
@@ -408,14 +412,39 @@ PyObject *get_2darray(const std::vector<::std::vector<Numeric>> &v) {
408412

409413
for (const ::std::vector<Numeric> &v_row : v) {
410414
if (v_row.size() != static_cast<size_t>(vsize[1]))
411-
throw std::runtime_error("Missmatched array size");
415+
throw std::runtime_error("mismatched array size");
412416
std::copy(v_row.begin(), v_row.end(), vd_begin);
413417
vd_begin += vsize[1];
414418
}
415419

416420
return reinterpret_cast<PyObject *>(varray);
417421
}
418422

423+
// suitable for Eigen matrices
424+
template <typename Matrix>
425+
PyObject *get_2darray(const Matrix &A) {
426+
detail::_interpreter::get(); // interpreter needs to be initialized for the
427+
// numpy commands to work
428+
if (A.size() < 1)
429+
throw std::runtime_error("get_2darray A too small");
430+
431+
npy_intp vsize[2] = {static_cast<npy_intp>(A.rows()),
432+
static_cast<npy_intp>(A.cols())};
433+
434+
PyArrayObject *varray =
435+
(PyArrayObject *)PyArray_SimpleNew(2, vsize, NPY_DOUBLE);
436+
437+
double *vd_begin = static_cast<double *>(PyArray_DATA(varray));
438+
439+
for (std::size_t i = 0; i < A.rows(); ++i) {
440+
for (std::size_t j = 0; j < A.cols(); ++j) {
441+
*(vd_begin + i * A.cols() + j) = A(i, j);
442+
}
443+
}
444+
445+
return reinterpret_cast<PyObject *>(varray);
446+
}
447+
419448
#else // fallback if we don't have numpy: copy every element of the given vector
420449

421450
template <typename Vector> PyObject *get_array(const Vector &v) {
@@ -869,6 +898,29 @@ bool scatter(const VectorX &x, const VectorY &y, const double s = 1.0) {
869898
return res;
870899
}
871900

901+
// @brief Spy plot
902+
// @param A the matrix
903+
template <typename Matrix>
904+
bool spy(const Matrix &A, double precision=0) {
905+
PyObject *Aarray = get_2darray(A);
906+
907+
PyObject *kwargs = PyDict_New();
908+
PyDict_SetItemString(kwargs, "precision", PyFloat_FromDouble(precision));
909+
910+
PyObject *plot_args = PyTuple_New(1);
911+
PyTuple_SetItem(plot_args, 0, Aarray);
912+
913+
PyObject *res = PyObject_Call(
914+
detail::_interpreter::get().s_python_function_spy, plot_args, kwargs);
915+
916+
Py_DECREF(plot_args);
917+
Py_DECREF(kwargs);
918+
if (res)
919+
Py_DECREF(res);
920+
921+
return res;
922+
}
923+
872924
template <typename Numeric>
873925
bool bar(const std::vector<Numeric> &y, std::string ec = "black",
874926
std::string ls = "-", double lw = 1.0,

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy