mirror of
https://github.com/pybind/pybind11.git
synced 2026-05-13 09:46:10 +00:00
Strip padding fields in dtypes, update the tests
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <initializer_list>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
@@ -26,6 +27,8 @@ NAMESPACE_BEGIN(pybind11)
|
||||
namespace detail {
|
||||
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
|
||||
|
||||
object fix_dtype(object);
|
||||
|
||||
template <typename T>
|
||||
struct is_pod_struct {
|
||||
enum { value = std::is_pod<T>::value && // offsetof only works correctly for POD types
|
||||
@@ -47,7 +50,9 @@ public:
|
||||
API_PyArray_FromAny = 69,
|
||||
API_PyArray_NewCopy = 85,
|
||||
API_PyArray_NewFromDescr = 94,
|
||||
API_PyArray_DescrNewFromType = 9,
|
||||
API_PyArray_DescrConverter = 174,
|
||||
API_PyArray_EquivTypes = 182,
|
||||
API_PyArray_GetArrayParamsFromObject = 278,
|
||||
|
||||
NPY_C_CONTIGUOUS_ = 0x0001,
|
||||
@@ -61,7 +66,9 @@ public:
|
||||
NPY_LONG_, NPY_ULONG_,
|
||||
NPY_LONGLONG_, NPY_ULONGLONG_,
|
||||
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
|
||||
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_
|
||||
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
|
||||
NPY_OBJECT_ = 17,
|
||||
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
|
||||
};
|
||||
|
||||
static API lookup() {
|
||||
@@ -79,7 +86,9 @@ public:
|
||||
DECL_NPY_API(PyArray_FromAny);
|
||||
DECL_NPY_API(PyArray_NewCopy);
|
||||
DECL_NPY_API(PyArray_NewFromDescr);
|
||||
DECL_NPY_API(PyArray_DescrNewFromType);
|
||||
DECL_NPY_API(PyArray_DescrConverter);
|
||||
DECL_NPY_API(PyArray_EquivTypes);
|
||||
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
||||
#undef DECL_NPY_API
|
||||
return api;
|
||||
@@ -91,10 +100,12 @@ public:
|
||||
PyObject *(*PyArray_NewFromDescr_)
|
||||
(PyTypeObject *, PyObject *, int, Py_intptr_t *,
|
||||
Py_intptr_t *, void *, int, PyObject *);
|
||||
PyObject *(*PyArray_DescrNewFromType_)(int);
|
||||
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
|
||||
PyTypeObject *PyArray_Type_;
|
||||
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
|
||||
int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
|
||||
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
|
||||
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
|
||||
Py_ssize_t *, PyObject **, PyObject *);
|
||||
};
|
||||
@@ -113,52 +124,83 @@ public:
|
||||
Py_intptr_t shape = (Py_intptr_t) size;
|
||||
object tmp = object(api.PyArray_NewFromDescr_(
|
||||
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
|
||||
if (ptr && tmp)
|
||||
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
|
||||
if (!tmp)
|
||||
pybind11_fail("NumPy: unable to create array!");
|
||||
if (ptr)
|
||||
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
|
||||
m_ptr = tmp.release().ptr();
|
||||
}
|
||||
|
||||
array(const buffer_info &info) {
|
||||
PyObject *arr = nullptr, *descr = nullptr;
|
||||
int ndim = 0;
|
||||
Py_ssize_t dims[32];
|
||||
API& api = lookup_api();
|
||||
auto& api = lookup_api();
|
||||
|
||||
// Allocate non-zeroed memory if it hasn't been provided by the caller.
|
||||
// Normally, we could leave this null for NumPy to allocate memory for us, but
|
||||
// since we need a memoryview, the data pointer has to be non-null. NumPy uses
|
||||
// malloc if NPY_NEEDS_INIT is not set (in which case it uses calloc); however,
|
||||
// we don't have a desriptor yet (only a buffer format string), so we can't
|
||||
// access the flags. As long as we're not dealing with object dtypes/fields
|
||||
// though, the memory doesn't have to be zeroed so we use malloc.
|
||||
auto buf_info = info;
|
||||
if (!buf_info.ptr)
|
||||
// always allocate at least 1 element, same way as NumPy does it
|
||||
buf_info.ptr = std::malloc(std::max(info.size, (size_t) 1) * info.itemsize);
|
||||
if (!buf_info.ptr)
|
||||
pybind11_fail("NumPy: failed to allocate memory for buffer");
|
||||
// _dtype_from_pep3118 returns dtypes with padding fields in, however the array
|
||||
// constructor seems to then consume them, so we don't need to strip them ourselves
|
||||
auto numpy_internal = module::import("numpy.core._internal");
|
||||
auto dtype_from_fmt = (object) numpy_internal.attr("_dtype_from_pep3118");
|
||||
auto dtype = dtype_from_fmt(pybind11::str(info.format));
|
||||
auto dtype2 = strip_padding_fields(dtype);
|
||||
|
||||
// PyArray_GetArrayParamsFromObject seems to be the only low-level API function
|
||||
// that will accept arbitrary buffers (including structured types)
|
||||
auto view = memoryview(buf_info);
|
||||
auto res = api.PyArray_GetArrayParamsFromObject_(view.ptr(), nullptr, 1, &descr,
|
||||
&ndim, dims, &arr, nullptr);
|
||||
if (res < 0 || !arr || descr)
|
||||
// We expect arr to have a pointer to a newly created array, in which case all
|
||||
// other parameters like descr would be set to null, according to the API.
|
||||
pybind11_fail("NumPy: unable to convert buffer to an array");
|
||||
m_ptr = arr;
|
||||
object tmp(api.PyArray_NewFromDescr_(
|
||||
api.PyArray_Type_, dtype2.release().ptr(), (int) info.ndim, (Py_intptr_t *) &info.shape[0],
|
||||
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
|
||||
if (!tmp)
|
||||
pybind11_fail("NumPy: unable to create array!");
|
||||
if (info.ptr)
|
||||
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
|
||||
m_ptr = tmp.release().ptr();
|
||||
auto d = (object) this->attr("dtype");
|
||||
}
|
||||
|
||||
protected:
|
||||
// protected:
|
||||
static API &lookup_api() {
|
||||
static API api = API::lookup();
|
||||
return api;
|
||||
}
|
||||
|
||||
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
|
||||
|
||||
static object strip_padding_fields(object dtype) {
|
||||
// Recursively strip all void fields with empty names that are generated for
|
||||
// padding fields (as of NumPy v1.11).
|
||||
auto fields = dtype.attr("fields").cast<object>();
|
||||
if (fields.ptr() == Py_None)
|
||||
return dtype;
|
||||
|
||||
struct field_descr { pybind11::str name; object format; int_ offset; };
|
||||
std::vector<field_descr> field_descriptors;
|
||||
|
||||
auto items = fields.attr("items").cast<object>();
|
||||
for (auto field : items()) {
|
||||
auto spec = object(field, true).cast<tuple>();
|
||||
auto name = spec[0].cast<pybind11::str>();
|
||||
auto format = spec[1].cast<tuple>()[0].cast<object>();
|
||||
auto offset = spec[1].cast<tuple>()[1].cast<int_>();
|
||||
if (!len(name) && (std::string) dtype.attr("kind").cast<pybind11::str>() == "V")
|
||||
continue;
|
||||
field_descriptors.push_back({name, strip_padding_fields(format), offset});
|
||||
}
|
||||
|
||||
std::sort(field_descriptors.begin(), field_descriptors.end(),
|
||||
[](const field_descr& a, const field_descr& b) {
|
||||
return (int) a.offset < (int) b.offset;
|
||||
});
|
||||
|
||||
list names, formats, offsets;
|
||||
for (auto& descr : field_descriptors) {
|
||||
names.append(descr.name);
|
||||
formats.append(descr.format);
|
||||
offsets.append(descr.offset);
|
||||
}
|
||||
auto args = dict();
|
||||
args["names"] = names; args["formats"] = formats; args["offsets"] = offsets;
|
||||
args["itemsize"] = dtype.attr("itemsize").cast<int_>();
|
||||
|
||||
PyObject *descr = nullptr;
|
||||
if (!lookup_api().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr)
|
||||
pybind11_fail("NumPy: failed to create structured dtype");
|
||||
return object(descr, false);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
||||
@@ -233,9 +275,12 @@ DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
|
||||
struct field_descriptor {
|
||||
const char *name;
|
||||
size_t offset;
|
||||
size_t size;
|
||||
const char *format;
|
||||
object descr;
|
||||
};
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>::type> {
|
||||
static PYBIND11_DESCR name() { return _("user-defined"); }
|
||||
@@ -253,7 +298,7 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
|
||||
}
|
||||
|
||||
static void register_dtype(std::initializer_list<field_descriptor> fields) {
|
||||
array::API& api = array::lookup_api();
|
||||
auto& api = array::lookup_api();
|
||||
auto args = dict();
|
||||
list names { }, offsets { }, formats { };
|
||||
for (auto field : fields) {
|
||||
@@ -263,26 +308,47 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
|
||||
offsets.append(int_(field.offset));
|
||||
formats.append(field.descr);
|
||||
}
|
||||
args["names"] = names;
|
||||
args["offsets"] = offsets;
|
||||
args["formats"] = formats;
|
||||
args["names"] = names; args["offsets"] = offsets; args["formats"] = formats;
|
||||
args["itemsize"] = int_(sizeof(T));
|
||||
// This is essentially the same as calling np.dtype() constructor in Python and passing
|
||||
// it a dict of the form {'names': ..., 'formats': ..., 'offsets': ...}.
|
||||
if (!api.PyArray_DescrConverter_(args.release().ptr(), &dtype_()) || !dtype_())
|
||||
pybind11_fail("NumPy: failed to create structured dtype");
|
||||
// Let NumPy figure the buffer format string for us: memoryview(np.empty(0, dtype)).format
|
||||
auto np = module::import("numpy");
|
||||
auto empty = (object) np.attr("empty");
|
||||
if (auto arr = (object) empty(int_(0), dtype())) {
|
||||
if (auto view = PyMemoryView_FromObject(arr.ptr())) {
|
||||
if (auto info = PyMemoryView_GET_BUFFER(view)) {
|
||||
std::strncpy(format_(), info->format, 4096);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
|
||||
// not encoded explicitly into the format string. This will supposedly
|
||||
// get fixed in v1.12; for further details, see these:
|
||||
// - https://github.com/numpy/numpy/issues/7797
|
||||
// - https://github.com/numpy/numpy/pull/7798
|
||||
// Because of this, we won't use numpy's logic to generate buffer format
|
||||
// strings and will just do it ourselves.
|
||||
std::vector<field_descriptor> ordered_fields(fields);
|
||||
std::sort(ordered_fields.begin(), ordered_fields.end(),
|
||||
[](const field_descriptor& a, const field_descriptor &b) {
|
||||
return a.offset < b.offset;
|
||||
});
|
||||
size_t offset = 0;
|
||||
std::ostringstream oss;
|
||||
oss << "T{";
|
||||
for (auto& field : ordered_fields) {
|
||||
if (field.offset > offset)
|
||||
oss << (field.offset - offset) << 'x';
|
||||
// note that '=' is required to cover the case of unaligned fields
|
||||
oss << '=' << field.format << ':' << field.name << ':';
|
||||
offset = field.offset + field.size;
|
||||
}
|
||||
pybind11_fail("NumPy: failed to extract buffer format");
|
||||
if (sizeof(T) > offset)
|
||||
oss << (sizeof(T) - offset) << 'x';
|
||||
oss << '}';
|
||||
std::strncpy(format_(), oss.str().c_str(), 4096);
|
||||
|
||||
// Sanity check: verify that NumPy properly parses our buffer format string
|
||||
auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1, { 0 }, { sizeof(T) }));
|
||||
auto dtype = (object) arr.attr("dtype");
|
||||
auto fixed_dtype = dtype;
|
||||
// auto fixed_dtype = array::strip_padding_fields(object(dtype_(), true));
|
||||
// if (!api.PyArray_EquivTypes_(dtype_(), fixed_dtype.ptr()))
|
||||
// pybind11_fail("NumPy: invalid buffer descriptor!");
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -293,7 +359,8 @@ private:
|
||||
// Extract name, offset and format descriptor for a struct field
|
||||
#define PYBIND11_FIELD_DESCRIPTOR(Type, Field) \
|
||||
::pybind11::detail::field_descriptor { \
|
||||
#Field, offsetof(Type, Field), \
|
||||
#Field, offsetof(Type, Field), sizeof(decltype(static_cast<Type*>(0)->Field)), \
|
||||
::pybind11::format_descriptor<decltype(static_cast<Type*>(0)->Field)>::format(), \
|
||||
::pybind11::detail::npy_format_descriptor<decltype(static_cast<Type*>(0)->Field)>::dtype() \
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user