Add function for comparing buffer_info formats to types

Allows equivalent integral types and numpy dtypes
This commit is contained in:
Patrick Stewart
2016-11-21 17:40:43 +00:00
committed by Wenzel Jakob
parent 5467979588
commit 0b6d08a008
5 changed files with 54 additions and 1 deletions

View File

@@ -623,6 +623,24 @@ template <typename T> struct format_descriptor<T, detail::enable_if_t<detail::is
template <typename T> constexpr const char format_descriptor<
T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>>::value[2];
NAMESPACE_BEGIN(detail)
template <typename T, typename SFINAE = void> struct compare_buffer_info {
static bool compare(const buffer_info& b) {
return b.format == format_descriptor<T>::format() && b.itemsize == sizeof(T);
}
};
template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
static bool compare(const buffer_info& b) {
return b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
}
};
NAMESPACE_END(detail)
/// RAII wrapper that temporarily clears any Python error state
struct error_scope {
PyObject *type, *value, *trace;

View File

@@ -703,6 +703,13 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
};
template <typename T>
struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
static bool compare(const buffer_info& b) {
return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr());
}
};
template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
private:
// NB: the order here must match the one in common.h