mirror of
https://github.com/pybind/pybind11.git
synced 2026-05-24 23:07:00 +00:00
feat: make numpy.h compatible with both NumPy 1.x and 2.x (#5050)
* API: Make `numpy.h` compatible with both NumPy 1.x and 2.x * TST: Update numpy dtype flags test to not covert flags to char * API: Add `numpy2.h` instead and make `numpy.h` safe This means that users of `numpy.h` cannot be broken, but need to update to `numpy2.h` if they want to compile for NumPy 2. Using Macros simply and didn't bother to try to remove unnecessary code paths. * API: Rather than `numpy2.h` use a define for the user. * Thread `PYBIND11_NUMPY2_SUPPORT` through things and try to adept test matrix * Small fixups (shouldn't matter)? * Fixup. Does upgrading scipy help? (it shouldn't?) (Some other small fixup) * Use NumPy 2 nightlies for ubuntu-latest job also * BUG: Fix numpy.bool check * TST: Fix complexwarning * BUG: Fix the fact that only the 50 slot is filled with the copy alias (There were 3 functions all doing the same, only this slot survived 2.x) * TST: One more test tweak * TST: Use "long" name for long, since it changed on windows * TST: Apparently we didn't always have ulong, so just use `L` * TST: Enforce dtype='l' for test as default isn't long anymore on windows * Rename macro and invert logic to PYBIND11_NUMPY_1_ONLY * PYBIND11_INTERNAL_NUMPY_1_ONLY_DETECTED * Test and code comment expansion * CI: Use pre-releases of numpy/scipy from pip via explicit version * CI: NumPy 2 only available on almalinux (as it is Python >=3.9) * MAINT: Match name more exactly and adopt error phrasing * MAINT: Pushed early, move helper to be private member * fix error message compilation when using NumPy 1.x-only backcompat * silence name shadowing warning * chore: minor optimization Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com> --------- Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com> Co-authored-by: Ralf W. Grosse-Kunstleve <rwgk@google.com> Co-authored-by: Henry Schreiner <henryschreineriii@gmail.com>
This commit is contained in:
@@ -29,10 +29,15 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if defined(PYBIND11_NUMPY_1_ONLY) && !defined(PYBIND11_INTERNAL_NUMPY_1_ONLY_DETECTED)
|
||||
# error PYBIND11_NUMPY_1_ONLY must be defined before any pybind11 header is included.
|
||||
#endif
|
||||
|
||||
/* This will be true on all flat address space platforms and allows us to reduce the
|
||||
whole npy_intp / ssize_t / Py_intptr_t business down to just ssize_t for all size
|
||||
and dimension types (e.g. shape, strides, indexing), instead of inflicting this
|
||||
upon the library user. */
|
||||
upon the library user.
|
||||
Note that NumPy 2 now uses ssize_t for `npy_intp` to simplify this. */
|
||||
static_assert(sizeof(::pybind11::ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t");
|
||||
static_assert(std::is_signed<Py_intptr_t>::value, "Py_intptr_t must be signed");
|
||||
// We now can reinterpret_cast between py::ssize_t and Py_intptr_t (MSVC + PyPy cares)
|
||||
@@ -53,7 +58,8 @@ struct handle_type_name<array> {
|
||||
template <typename type, typename SFINAE = void>
|
||||
struct npy_format_descriptor;
|
||||
|
||||
struct PyArrayDescr_Proxy {
|
||||
/* NumPy 1 proxy (always includes legacy fields) */
|
||||
struct PyArrayDescr1_Proxy {
|
||||
PyObject_HEAD
|
||||
PyObject *typeobj;
|
||||
char kind;
|
||||
@@ -68,6 +74,43 @@ struct PyArrayDescr_Proxy {
|
||||
PyObject *names;
|
||||
};
|
||||
|
||||
#ifndef PYBIND11_NUMPY_1_ONLY
|
||||
struct PyArrayDescr_Proxy {
|
||||
PyObject_HEAD
|
||||
PyObject *typeobj;
|
||||
char kind;
|
||||
char type;
|
||||
char byteorder;
|
||||
char _former_flags;
|
||||
int type_num;
|
||||
/* Additional fields are NumPy version specific. */
|
||||
};
|
||||
#else
|
||||
/* NumPy 1.x only, we can expose all fields */
|
||||
using PyArrayDescr_Proxy = PyArrayDescr1_Proxy;
|
||||
#endif
|
||||
|
||||
/* NumPy 2 proxy, including legacy fields */
|
||||
struct PyArrayDescr2_Proxy {
|
||||
PyObject_HEAD
|
||||
PyObject *typeobj;
|
||||
char kind;
|
||||
char type;
|
||||
char byteorder;
|
||||
char _former_flags;
|
||||
int type_num;
|
||||
std::uint64_t flags;
|
||||
ssize_t elsize;
|
||||
ssize_t alignment;
|
||||
PyObject *metadata;
|
||||
Py_hash_t hash;
|
||||
void *reserved_null[2];
|
||||
/* The following fields only exist if 0 <= type_num < 2056 */
|
||||
char *subarray;
|
||||
PyObject *fields;
|
||||
PyObject *names;
|
||||
};
|
||||
|
||||
struct PyArray_Proxy {
|
||||
PyObject_HEAD
|
||||
char *data;
|
||||
@@ -131,6 +174,14 @@ PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char *submodule_name
|
||||
object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
|
||||
int major_version = numpy_version.attr("major").cast<int>();
|
||||
|
||||
#ifdef PYBIND11_NUMPY_1_ONLY
|
||||
if (major_version >= 2) {
|
||||
throw std::runtime_error(
|
||||
"This extension was built with PYBIND11_NUMPY_1_ONLY defined, "
|
||||
"but NumPy 2 is used in this process. For NumPy2 compatibility, "
|
||||
"this extension needs to be rebuilt without the PYBIND11_NUMPY_1_ONLY define.");
|
||||
}
|
||||
#endif
|
||||
/* `numpy.core` was renamed to `numpy._core` in NumPy 2.0 as it officially
|
||||
became a private module. */
|
||||
std::string numpy_core_path = major_version >= 2 ? "numpy._core" : "numpy.core";
|
||||
@@ -203,6 +254,8 @@ struct npy_api {
|
||||
NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
|
||||
};
|
||||
|
||||
unsigned int PyArray_RUNTIME_VERSION_;
|
||||
|
||||
struct PyArray_Dims {
|
||||
Py_intptr_t *ptr;
|
||||
int len;
|
||||
@@ -241,6 +294,7 @@ struct npy_api {
|
||||
PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *);
|
||||
int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
|
||||
bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
|
||||
#ifdef PYBIND11_NUMPY_1_ONLY
|
||||
int (*PyArray_GetArrayParamsFromObject_)(PyObject *,
|
||||
PyObject *,
|
||||
unsigned char,
|
||||
@@ -249,6 +303,7 @@ struct npy_api {
|
||||
Py_intptr_t *,
|
||||
PyObject **,
|
||||
PyObject *);
|
||||
#endif
|
||||
PyObject *(*PyArray_Squeeze_)(PyObject *);
|
||||
// Unused. Not removed because that affects ABI of the class.
|
||||
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
|
||||
@@ -266,7 +321,8 @@ private:
|
||||
API_PyArray_DescrFromScalar = 57,
|
||||
API_PyArray_FromAny = 69,
|
||||
API_PyArray_Resize = 80,
|
||||
API_PyArray_CopyInto = 82,
|
||||
// CopyInto was slot 82 and 50 was effectively an alias. NumPy 2 removed 82.
|
||||
API_PyArray_CopyInto = 50,
|
||||
API_PyArray_NewCopy = 85,
|
||||
API_PyArray_NewFromDescr = 94,
|
||||
API_PyArray_DescrNewFromType = 96,
|
||||
@@ -275,7 +331,9 @@ private:
|
||||
API_PyArray_View = 137,
|
||||
API_PyArray_DescrConverter = 174,
|
||||
API_PyArray_EquivTypes = 182,
|
||||
#ifdef PYBIND11_NUMPY_1_ONLY
|
||||
API_PyArray_GetArrayParamsFromObject = 278,
|
||||
#endif
|
||||
API_PyArray_SetBaseObject = 282
|
||||
};
|
||||
|
||||
@@ -290,7 +348,8 @@ private:
|
||||
npy_api api;
|
||||
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
|
||||
DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
|
||||
if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7) {
|
||||
api.PyArray_RUNTIME_VERSION_ = api.PyArray_GetNDArrayCFeatureVersion_();
|
||||
if (api.PyArray_RUNTIME_VERSION_ < 0x7) {
|
||||
pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
|
||||
}
|
||||
DECL_NPY_API(PyArray_Type);
|
||||
@@ -309,7 +368,9 @@ private:
|
||||
DECL_NPY_API(PyArray_View);
|
||||
DECL_NPY_API(PyArray_DescrConverter);
|
||||
DECL_NPY_API(PyArray_EquivTypes);
|
||||
#ifdef PYBIND11_NUMPY_1_ONLY
|
||||
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
||||
#endif
|
||||
DECL_NPY_API(PyArray_SetBaseObject);
|
||||
|
||||
#undef DECL_NPY_API
|
||||
@@ -331,6 +392,14 @@ inline const PyArrayDescr_Proxy *array_descriptor_proxy(const PyObject *ptr) {
|
||||
return reinterpret_cast<const PyArrayDescr_Proxy *>(ptr);
|
||||
}
|
||||
|
||||
inline const PyArrayDescr1_Proxy *array_descriptor1_proxy(const PyObject *ptr) {
|
||||
return reinterpret_cast<const PyArrayDescr1_Proxy *>(ptr);
|
||||
}
|
||||
|
||||
inline const PyArrayDescr2_Proxy *array_descriptor2_proxy(const PyObject *ptr) {
|
||||
return reinterpret_cast<const PyArrayDescr2_Proxy *>(ptr);
|
||||
}
|
||||
|
||||
inline bool check_flags(const void *ptr, int flag) {
|
||||
return (flag == (array_proxy(ptr)->flags & flag));
|
||||
}
|
||||
@@ -610,10 +679,32 @@ public:
|
||||
}
|
||||
|
||||
/// Size of the data type in bytes.
|
||||
#ifdef PYBIND11_NUMPY_1_ONLY
|
||||
ssize_t itemsize() const { return detail::array_descriptor_proxy(m_ptr)->elsize; }
|
||||
#else
|
||||
ssize_t itemsize() const {
|
||||
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
|
||||
return detail::array_descriptor1_proxy(m_ptr)->elsize;
|
||||
}
|
||||
return detail::array_descriptor2_proxy(m_ptr)->elsize;
|
||||
}
|
||||
#endif
|
||||
|
||||
/// Returns true for structured data types.
|
||||
#ifdef PYBIND11_NUMPY_1_ONLY
|
||||
bool has_fields() const { return detail::array_descriptor_proxy(m_ptr)->names != nullptr; }
|
||||
#else
|
||||
bool has_fields() const {
|
||||
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
|
||||
return detail::array_descriptor1_proxy(m_ptr)->names != nullptr;
|
||||
}
|
||||
const auto *proxy = detail::array_descriptor2_proxy(m_ptr);
|
||||
if (proxy->type_num < 0 || proxy->type_num >= 2056) {
|
||||
return false;
|
||||
}
|
||||
return proxy->names != nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
/// Single-character code for dtype's kind.
|
||||
/// For example, floating point types are 'f' and integral types are 'i'.
|
||||
@@ -639,11 +730,29 @@ public:
|
||||
/// Single character for byteorder
|
||||
char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; }
|
||||
|
||||
/// Alignment of the data type
|
||||
/// Alignment of the data type
|
||||
#ifdef PYBIND11_NUMPY_1_ONLY
|
||||
int alignment() const { return detail::array_descriptor_proxy(m_ptr)->alignment; }
|
||||
#else
|
||||
ssize_t alignment() const {
|
||||
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
|
||||
return detail::array_descriptor1_proxy(m_ptr)->alignment;
|
||||
}
|
||||
return detail::array_descriptor2_proxy(m_ptr)->alignment;
|
||||
}
|
||||
#endif
|
||||
|
||||
/// Flags for the array descriptor
|
||||
/// Flags for the array descriptor
|
||||
#ifdef PYBIND11_NUMPY_1_ONLY
|
||||
char flags() const { return detail::array_descriptor_proxy(m_ptr)->flags; }
|
||||
#else
|
||||
std::uint64_t flags() const {
|
||||
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
|
||||
return (unsigned char) detail::array_descriptor1_proxy(m_ptr)->flags;
|
||||
}
|
||||
return detail::array_descriptor2_proxy(m_ptr)->flags;
|
||||
}
|
||||
#endif
|
||||
|
||||
private:
|
||||
static object &_dtype_from_pep3118() {
|
||||
@@ -810,9 +919,7 @@ public:
|
||||
}
|
||||
|
||||
/// Byte size of a single element
|
||||
ssize_t itemsize() const {
|
||||
return detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
|
||||
}
|
||||
ssize_t itemsize() const { return dtype().itemsize(); }
|
||||
|
||||
/// Total number of bytes
|
||||
ssize_t nbytes() const { return size() * itemsize(); }
|
||||
|
||||
Reference in New Issue
Block a user