mirror of
https://github.com/pybind/pybind11.git
synced 2026-05-11 17:00:34 +00:00
Fix eigen copying of non-standard stride values
Some Eigen objects, such as those returned by matrix.diagonal() and matrix.block() have non-standard stride values because they are basically just maps onto the underlying matrix without copying it (for example, the primary diagonal of a 3x3 matrix is a vector-like object with .src equal to the full matrix data, but with stride 4). Returning such an object from a pybind11 method breaks, however, because pybind11 assumes vectors have stride 1, and that matrices have strides equal to the number of rows/columns or 1 (depending on whether the matrix is stored column-major or row-major). This commit fixes the issue by making pybind11 use Eigen's stride methods when copying the data.
This commit is contained in:
@@ -12,6 +12,8 @@ from example import sparse_passthrough_r, sparse_passthrough_c
|
||||
from example import double_row, double_col
|
||||
from example import double_mat_cm, double_mat_rm
|
||||
from example import cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6
|
||||
from example import diagonal, diagonal_1, diagonal_n
|
||||
from example import block
|
||||
try:
|
||||
import numpy as np
|
||||
import scipy
|
||||
@@ -78,3 +80,11 @@ for chol in [cholesky1, cholesky2, cholesky3, cholesky4, cholesky5, cholesky6]:
|
||||
print("cholesky" + str(i) + " " + ("OK" if (mymat == np.array([[1,0,0], [2,3,0], [4,5,6]])).all() else "NOT OKAY"))
|
||||
i += 1
|
||||
|
||||
print("diagonal() %s" % ("OK" if (diagonal(ref) == ref.diagonal()).all() else "FAILED"))
|
||||
print("diagonal_1() %s" % ("OK" if (diagonal_1(ref) == ref.diagonal(1)).all() else "FAILED"))
|
||||
for i in range(-5, 7):
|
||||
print("diagonal_n(%d) %s" % (i, "OK" if (diagonal_n(ref, i) == ref.diagonal(i)).all() else "FAILED"))
|
||||
|
||||
print("block(2,1,3,3) %s" % ("OK" if (block(ref, 2, 1, 3, 3) == ref[2:5, 1:4]).all() else "FAILED"))
|
||||
print("block(1,4,4,2) %s" % ("OK" if (block(ref, 1, 4, 4, 2) == ref[1:, 4:]).all() else "FAILED"))
|
||||
print("block(1,4,3,2) %s" % ("OK" if (block(ref, 1, 4, 3, 2) == ref[1:4, 4:]).all() else "FAILED"))
|
||||
|
||||
Reference in New Issue
Block a user