mirror of
https://github.com/pybind/pybind11.git
synced 2026-05-13 09:46:10 +00:00
Add support for Eigen::Ref<...> function arguments
Eigen::Ref is a common way to pass eigen dense types without needing a template, e.g. the single definition `void func(Eigen::Ref<Eigen::MatrixXd> x)` can be called with any double matrix-like object. The current pybind11 eigen support fails with internal errors if attempting to bind a function with an Eigen::Ref<...> argument because Eigen::Ref<...> satisfies the "is_eigen_dense" requirement, but can't compile if actually used: Eigen::Ref<...> itself is not default constructible, and so the argument std::tuple containing an Eigen::Ref<...> isn't constructible, which results in compilation failure. This commit adds support for Eigen::Ref<...> by giving it its own type_caster implementation which consists of an internal type_caster of the referenced type, load/cast methods that dispatch to the internal type_caster, and a unique_ptr to an Eigen::Ref<> instance that gets set during load(). There is, of course, no performance advantage for pybind11-using code of using Eigen::Ref<...>--we are allocating a matrix of the derived type when loading it--but this has the advantage of allowing pybind11 to bind transparently to C++ methods taking Eigen::Refs.
This commit is contained in:
@@ -40,6 +40,19 @@ public:
|
||||
static constexpr bool value = decltype(test(std::declval<T>()))::value;
|
||||
};
|
||||
|
||||
// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructible, which means we can't load
|
||||
// it (since there is no reference!), but we can cast from it.
|
||||
template <typename T> class is_eigen_ref {
|
||||
private:
|
||||
template<typename Derived> static typename std::enable_if<
|
||||
std::is_same<typename std::remove_const<T>::type, Eigen::Ref<Derived>>::value,
|
||||
Derived>::type test(const Eigen::Ref<Derived> &);
|
||||
static void test(...);
|
||||
public:
|
||||
typedef decltype(test(std::declval<T>())) Derived;
|
||||
static constexpr bool value = !std::is_void<Derived>::value;
|
||||
};
|
||||
|
||||
template <typename T> class is_eigen_sparse {
|
||||
private:
|
||||
template<typename Derived> static std::true_type test(const Eigen::SparseMatrixBase<Derived> &);
|
||||
@@ -49,7 +62,7 @@ public:
|
||||
};
|
||||
|
||||
template<typename Type>
|
||||
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value>::type> {
|
||||
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && !is_eigen_ref<Type>::value>::type> {
|
||||
typedef typename Type::Scalar Scalar;
|
||||
static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
|
||||
static constexpr bool isVector = Type::IsVectorAtCompileTime;
|
||||
@@ -149,6 +162,26 @@ protected:
|
||||
static PYBIND11_DESCR cols() { return _<T::ColsAtCompileTime>(); }
|
||||
};
|
||||
|
||||
template<typename Type>
|
||||
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && is_eigen_ref<Type>::value>::type> {
|
||||
private:
|
||||
using Derived = typename std::remove_const<typename is_eigen_ref<Type>::Derived>::type;
|
||||
using DerivedCaster = type_caster<Derived>;
|
||||
DerivedCaster derived_caster;
|
||||
protected:
|
||||
std::unique_ptr<Type> value;
|
||||
public:
|
||||
bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; }
|
||||
static handle cast(const Type &src, return_value_policy policy, handle parent) { return DerivedCaster::cast(src, policy, parent); }
|
||||
static handle cast(const Type *src, return_value_policy policy, handle parent) { return DerivedCaster::cast(*src, policy, parent); }
|
||||
|
||||
static PYBIND11_DESCR name() { return DerivedCaster::name(); }
|
||||
|
||||
operator Type*() { return value.get(); }
|
||||
operator Type&() { if (!value) pybind11_fail("Eigen::Ref<...> value not loaded"); return *value; }
|
||||
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
|
||||
};
|
||||
|
||||
template<typename Type>
|
||||
struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::type> {
|
||||
typedef typename Type::Scalar Scalar;
|
||||
|
||||
Reference in New Issue
Block a user