mirror of
https://github.com/pybind/pybind11.git
synced 2026-05-24 14:55:01 +00:00
Support arrays inside PYBIND11_NUMPY_DTYPE (#832)
Resolves #800. Both C++ arrays and std::array are supported, including mixtures like std::array<int, 2>[4]. In a multi-dimensional array of char, the last dimension is used to construct a numpy string type.
This commit is contained in:
committed by
Dean Moldovan
parent
78f1dcf98f
commit
8e0d832c7d
@@ -246,6 +246,46 @@ template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::tru
|
||||
template <typename T> struct is_complex : std::false_type { };
|
||||
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };
|
||||
|
||||
template <typename T> struct array_info_scalar {
|
||||
typedef T type;
|
||||
static constexpr bool is_array = false;
|
||||
static constexpr bool is_empty = false;
|
||||
static PYBIND11_DESCR extents() { return _(""); }
|
||||
static void append_extents(list& /* shape */) { }
|
||||
};
|
||||
// Computes underlying type and a comma-separated list of extents for array
|
||||
// types (any mix of std::array and built-in arrays). An array of char is
|
||||
// treated as scalar because it gets special handling.
|
||||
template <typename T> struct array_info : array_info_scalar<T> { };
|
||||
template <typename T, size_t N> struct array_info<std::array<T, N>> {
|
||||
using type = typename array_info<T>::type;
|
||||
static constexpr bool is_array = true;
|
||||
static constexpr bool is_empty = (N == 0) || array_info<T>::is_empty;
|
||||
static constexpr size_t extent = N;
|
||||
|
||||
// appends the extents to shape
|
||||
static void append_extents(list& shape) {
|
||||
shape.append(N);
|
||||
array_info<T>::append_extents(shape);
|
||||
}
|
||||
|
||||
template<typename T2 = T, enable_if_t<!array_info<T2>::is_array, int> = 0>
|
||||
static PYBIND11_DESCR extents() {
|
||||
return _<N>();
|
||||
}
|
||||
|
||||
template<typename T2 = T, enable_if_t<array_info<T2>::is_array, int> = 0>
|
||||
static PYBIND11_DESCR extents() {
|
||||
return concat(_<N>(), array_info<T>::extents());
|
||||
}
|
||||
};
|
||||
// For numpy we have special handling for arrays of characters, so we don't include
|
||||
// the size in the array extents.
|
||||
template <size_t N> struct array_info<char[N]> : array_info_scalar<char[N]> { };
|
||||
template <size_t N> struct array_info<std::array<char, N>> : array_info_scalar<std::array<char, N>> { };
|
||||
template <typename T, size_t N> struct array_info<T[N]> : array_info<std::array<T, N>> { };
|
||||
template <typename T> using remove_all_extents_t = typename array_info<T>::type;
|
||||
|
||||
template <typename T> using is_pod_struct = all_of<
|
||||
std::is_pod<T>, // since we're accessing directly in memory we need a POD type
|
||||
satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
|
||||
@@ -745,6 +785,8 @@ protected:
|
||||
|
||||
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
||||
public:
|
||||
static_assert(!detail::array_info<T>::is_array, "Array types cannot be used with array_t");
|
||||
|
||||
using value_type = T;
|
||||
|
||||
array_t() : array(0, static_cast<const T *>(nullptr)) {}
|
||||
@@ -871,6 +913,15 @@ struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct format_descriptor<T, detail::enable_if_t<detail::array_info<T>::is_array>> {
|
||||
static std::string format() {
|
||||
using detail::_;
|
||||
PYBIND11_DESCR extents = _("(") + detail::array_info<T>::extents() + _(")");
|
||||
return extents.text() + format_descriptor<detail::remove_all_extents_t<T>>::format();
|
||||
}
|
||||
};
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
template <typename T, int ExtraFlags>
|
||||
struct pyobject_caster<array_t<T, ExtraFlags>> {
|
||||
@@ -939,6 +990,20 @@ template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_F
|
||||
template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
|
||||
#undef PYBIND11_DECL_CHAR_FMT
|
||||
|
||||
template<typename T> struct npy_format_descriptor<T, enable_if_t<array_info<T>::is_array>> {
|
||||
private:
|
||||
using base_descr = npy_format_descriptor<typename array_info<T>::type>;
|
||||
public:
|
||||
static_assert(!array_info<T>::is_empty, "Zero-sized arrays are not supported");
|
||||
|
||||
static PYBIND11_DESCR name() { return _("(") + array_info<T>::extents() + _(")") + base_descr::name(); }
|
||||
static pybind11::dtype dtype() {
|
||||
list shape;
|
||||
array_info<T>::append_extents(shape);
|
||||
return pybind11::dtype::from_args(pybind11::make_tuple(base_descr::dtype(), shape));
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
|
||||
private:
|
||||
using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
|
||||
|
||||
Reference in New Issue
Block a user