Add support for Eigen::Ref<...> function arguments

Eigen::Ref is a common way to pass eigen dense types without needing a
template, e.g. the single definition `void
func(Eigen::Ref<Eigen::MatrixXd> x)` can be called with any double
matrix-like object.

The current pybind11 eigen support fails with internal errors if
attempting to bind a function with an Eigen::Ref<...> argument because
Eigen::Ref<...> satisfies the "is_eigen_dense" requirement, but can't
compile if actually used: Eigen::Ref<...> itself is not default
constructible, and so the argument std::tuple containing an
Eigen::Ref<...> isn't constructible, which results in compilation
failure.

This commit adds support for Eigen::Ref<...> by giving it its own
type_caster implementation which consists of an internal type_caster of
the referenced type, load/cast methods that dispatch to the internal
type_caster, and a unique_ptr to an Eigen::Ref<> instance that gets
set during load().

There is, of course, no performance advantage for pybind11-using code of
using Eigen::Ref<...>--we are allocating a matrix of the derived type
when loading it--but this has the advantage of allowing pybind11 to bind
transparently to C++ methods taking Eigen::Refs.
This commit is contained in:
Jason Rhinelander
2016-08-03 16:50:22 -04:00
parent 7f9603fe24
commit 5fd5074a0b
4 changed files with 63 additions and 1 deletions

View File

@@ -40,6 +40,19 @@ public:
static constexpr bool value = decltype(test(std::declval<T>()))::value;
};
// Eigen::Ref<Derived> satisfies is_eigen_dense, but isn't constructible, which means we can't load
// it (since there is no reference!), but we can cast from it.
template <typename T> class is_eigen_ref {
private:
template<typename Derived> static typename std::enable_if<
std::is_same<typename std::remove_const<T>::type, Eigen::Ref<Derived>>::value,
Derived>::type test(const Eigen::Ref<Derived> &);
static void test(...);
public:
typedef decltype(test(std::declval<T>())) Derived;
static constexpr bool value = !std::is_void<Derived>::value;
};
template <typename T> class is_eigen_sparse {
private:
template<typename Derived> static std::true_type test(const Eigen::SparseMatrixBase<Derived> &);
@@ -49,7 +62,7 @@ public:
};
template<typename Type>
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value>::type> {
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && !is_eigen_ref<Type>::value>::type> {
typedef typename Type::Scalar Scalar;
static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit;
static constexpr bool isVector = Type::IsVectorAtCompileTime;
@@ -149,6 +162,26 @@ protected:
static PYBIND11_DESCR cols() { return _<T::ColsAtCompileTime>(); }
};
template<typename Type>
struct type_caster<Type, typename std::enable_if<is_eigen_dense<Type>::value && is_eigen_ref<Type>::value>::type> {
private:
using Derived = typename std::remove_const<typename is_eigen_ref<Type>::Derived>::type;
using DerivedCaster = type_caster<Derived>;
DerivedCaster derived_caster;
protected:
std::unique_ptr<Type> value;
public:
bool load(handle src, bool convert) { if (derived_caster.load(src, convert)) { value.reset(new Type(derived_caster.operator Derived&())); return true; } return false; }
static handle cast(const Type &src, return_value_policy policy, handle parent) { return DerivedCaster::cast(src, policy, parent); }
static handle cast(const Type *src, return_value_policy policy, handle parent) { return DerivedCaster::cast(*src, policy, parent); }
static PYBIND11_DESCR name() { return DerivedCaster::name(); }
operator Type*() { return value.get(); }
operator Type&() { if (!value) pybind11_fail("Eigen::Ref<...> value not loaded"); return *value; }
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
};
template<typename Type>
struct type_caster<Type, typename std::enable_if<is_eigen_sparse<Type>::value>::type> {
typedef typename Type::Scalar Scalar;