Strip padding fields in dtypes, update the tests

This commit is contained in:
Ivan Smirnov
2016-07-06 00:28:12 +01:00
parent 13022f1b8c
commit 8fa09cb871
4 changed files with 178 additions and 56 deletions

View File

@@ -15,6 +15,7 @@
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <initializer_list>
#if defined(_MSC_VER)
@@ -26,6 +27,8 @@ NAMESPACE_BEGIN(pybind11)
namespace detail {
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
object fix_dtype(object);
template <typename T>
struct is_pod_struct {
enum { value = std::is_pod<T>::value && // offsetof only works correctly for POD types
@@ -47,7 +50,9 @@ public:
API_PyArray_FromAny = 69,
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94,
API_PyArray_DescrNewFromType = 9,
API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278,
NPY_C_CONTIGUOUS_ = 0x0001,
@@ -61,7 +66,9 @@ public:
NPY_LONG_, NPY_ULONG_,
NPY_LONGLONG_, NPY_ULONGLONG_,
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
NPY_OBJECT_ = 17,
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
};
static API lookup() {
@@ -79,7 +86,9 @@ public:
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
DECL_NPY_API(PyArray_DescrNewFromType);
DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
#undef DECL_NPY_API
return api;
@@ -91,10 +100,12 @@ public:
PyObject *(*PyArray_NewFromDescr_)
(PyTypeObject *, PyObject *, int, Py_intptr_t *,
Py_intptr_t *, void *, int, PyObject *);
PyObject *(*PyArray_DescrNewFromType_)(int);
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
PyTypeObject *PyArray_Type_;
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
Py_ssize_t *, PyObject **, PyObject *);
};
@@ -113,52 +124,83 @@ public:
Py_intptr_t shape = (Py_intptr_t) size;
object tmp = object(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
if (ptr && tmp)
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
if (!tmp)
pybind11_fail("NumPy: unable to create array!");
if (ptr)
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
m_ptr = tmp.release().ptr();
}
array(const buffer_info &info) {
PyObject *arr = nullptr, *descr = nullptr;
int ndim = 0;
Py_ssize_t dims[32];
API& api = lookup_api();
auto& api = lookup_api();
// Allocate non-zeroed memory if it hasn't been provided by the caller.
// Normally, we could leave this null for NumPy to allocate memory for us, but
// since we need a memoryview, the data pointer has to be non-null. NumPy uses
// malloc if NPY_NEEDS_INIT is not set (in which case it uses calloc); however,
// we don't have a desriptor yet (only a buffer format string), so we can't
// access the flags. As long as we're not dealing with object dtypes/fields
// though, the memory doesn't have to be zeroed so we use malloc.
auto buf_info = info;
if (!buf_info.ptr)
// always allocate at least 1 element, same way as NumPy does it
buf_info.ptr = std::malloc(std::max(info.size, (size_t) 1) * info.itemsize);
if (!buf_info.ptr)
pybind11_fail("NumPy: failed to allocate memory for buffer");
// _dtype_from_pep3118 returns dtypes with padding fields in, however the array
// constructor seems to then consume them, so we don't need to strip them ourselves
auto numpy_internal = module::import("numpy.core._internal");
auto dtype_from_fmt = (object) numpy_internal.attr("_dtype_from_pep3118");
auto dtype = dtype_from_fmt(pybind11::str(info.format));
auto dtype2 = strip_padding_fields(dtype);
// PyArray_GetArrayParamsFromObject seems to be the only low-level API function
// that will accept arbitrary buffers (including structured types)
auto view = memoryview(buf_info);
auto res = api.PyArray_GetArrayParamsFromObject_(view.ptr(), nullptr, 1, &descr,
&ndim, dims, &arr, nullptr);
if (res < 0 || !arr || descr)
// We expect arr to have a pointer to a newly created array, in which case all
// other parameters like descr would be set to null, according to the API.
pybind11_fail("NumPy: unable to convert buffer to an array");
m_ptr = arr;
object tmp(api.PyArray_NewFromDescr_(
api.PyArray_Type_, dtype2.release().ptr(), (int) info.ndim, (Py_intptr_t *) &info.shape[0],
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
if (!tmp)
pybind11_fail("NumPy: unable to create array!");
if (info.ptr)
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
m_ptr = tmp.release().ptr();
auto d = (object) this->attr("dtype");
}
protected:
// protected:
static API &lookup_api() {
static API api = API::lookup();
return api;
}
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
static object strip_padding_fields(object dtype) {
// Recursively strip all void fields with empty names that are generated for
// padding fields (as of NumPy v1.11).
auto fields = dtype.attr("fields").cast<object>();
if (fields.ptr() == Py_None)
return dtype;
struct field_descr { pybind11::str name; object format; int_ offset; };
std::vector<field_descr> field_descriptors;
auto items = fields.attr("items").cast<object>();
for (auto field : items()) {
auto spec = object(field, true).cast<tuple>();
auto name = spec[0].cast<pybind11::str>();
auto format = spec[1].cast<tuple>()[0].cast<object>();
auto offset = spec[1].cast<tuple>()[1].cast<int_>();
if (!len(name) && (std::string) dtype.attr("kind").cast<pybind11::str>() == "V")
continue;
field_descriptors.push_back({name, strip_padding_fields(format), offset});
}
std::sort(field_descriptors.begin(), field_descriptors.end(),
[](const field_descr& a, const field_descr& b) {
return (int) a.offset < (int) b.offset;
});
list names, formats, offsets;
for (auto& descr : field_descriptors) {
names.append(descr.name);
formats.append(descr.format);
offsets.append(descr.offset);
}
auto args = dict();
args["names"] = names; args["formats"] = formats; args["offsets"] = offsets;
args["itemsize"] = dtype.attr("itemsize").cast<int_>();
PyObject *descr = nullptr;
if (!lookup_api().PyArray_DescrConverter_(args.release().ptr(), &descr) || !descr)
pybind11_fail("NumPy: failed to create structured dtype");
return object(descr, false);
}
};
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
@@ -233,9 +275,12 @@ DECL_FMT(std::complex<double>, NPY_CDOUBLE_, "complex128");
struct field_descriptor {
const char *name;
size_t offset;
size_t size;
const char *format;
object descr;
};
template <typename T>
struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>::type> {
static PYBIND11_DESCR name() { return _("user-defined"); }
@@ -253,7 +298,7 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
}
static void register_dtype(std::initializer_list<field_descriptor> fields) {
array::API& api = array::lookup_api();
auto& api = array::lookup_api();
auto args = dict();
list names { }, offsets { }, formats { };
for (auto field : fields) {
@@ -263,26 +308,47 @@ struct npy_format_descriptor<T, typename std::enable_if<is_pod_struct<T>::value>
offsets.append(int_(field.offset));
formats.append(field.descr);
}
args["names"] = names;
args["offsets"] = offsets;
args["formats"] = formats;
args["names"] = names; args["offsets"] = offsets; args["formats"] = formats;
args["itemsize"] = int_(sizeof(T));
// This is essentially the same as calling np.dtype() constructor in Python and passing
// it a dict of the form {'names': ..., 'formats': ..., 'offsets': ...}.
if (!api.PyArray_DescrConverter_(args.release().ptr(), &dtype_()) || !dtype_())
pybind11_fail("NumPy: failed to create structured dtype");
// Let NumPy figure the buffer format string for us: memoryview(np.empty(0, dtype)).format
auto np = module::import("numpy");
auto empty = (object) np.attr("empty");
if (auto arr = (object) empty(int_(0), dtype())) {
if (auto view = PyMemoryView_FromObject(arr.ptr())) {
if (auto info = PyMemoryView_GET_BUFFER(view)) {
std::strncpy(format_(), info->format, 4096);
return;
}
}
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
// not encoded explicitly into the format string. This will supposedly
// get fixed in v1.12; for further details, see these:
// - https://github.com/numpy/numpy/issues/7797
// - https://github.com/numpy/numpy/pull/7798
// Because of this, we won't use numpy's logic to generate buffer format
// strings and will just do it ourselves.
std::vector<field_descriptor> ordered_fields(fields);
std::sort(ordered_fields.begin(), ordered_fields.end(),
[](const field_descriptor& a, const field_descriptor &b) {
return a.offset < b.offset;
});
size_t offset = 0;
std::ostringstream oss;
oss << "T{";
for (auto& field : ordered_fields) {
if (field.offset > offset)
oss << (field.offset - offset) << 'x';
// note that '=' is required to cover the case of unaligned fields
oss << '=' << field.format << ':' << field.name << ':';
offset = field.offset + field.size;
}
pybind11_fail("NumPy: failed to extract buffer format");
if (sizeof(T) > offset)
oss << (sizeof(T) - offset) << 'x';
oss << '}';
std::strncpy(format_(), oss.str().c_str(), 4096);
// Sanity check: verify that NumPy properly parses our buffer format string
auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1, { 0 }, { sizeof(T) }));
auto dtype = (object) arr.attr("dtype");
auto fixed_dtype = dtype;
// auto fixed_dtype = array::strip_padding_fields(object(dtype_(), true));
// if (!api.PyArray_EquivTypes_(dtype_(), fixed_dtype.ptr()))
// pybind11_fail("NumPy: invalid buffer descriptor!");
}
private:
@@ -293,7 +359,8 @@ private:
// Extract name, offset and format descriptor for a struct field
#define PYBIND11_FIELD_DESCRIPTOR(Type, Field) \
::pybind11::detail::field_descriptor { \
#Field, offsetof(Type, Field), \
#Field, offsetof(Type, Field), sizeof(decltype(static_cast<Type*>(0)->Field)), \
::pybind11::format_descriptor<decltype(static_cast<Type*>(0)->Field)>::format(), \
::pybind11::detail::npy_format_descriptor<decltype(static_cast<Type*>(0)->Field)>::dtype() \
}