feat: numpy scalars (#5726)

This commit is contained in:
Henry Schreiner
2025-06-18 19:40:31 -04:00
committed by GitHub
parent c60c14991d
commit cf3d1a75a2
5 changed files with 319 additions and 36 deletions

View File

@@ -232,6 +232,46 @@ prevent many types of unsupported structures, it is still the user's
responsibility to use only "plain" structures that can be safely manipulated as
raw memory without violating invariants.
Scalar types
============
In some cases we may want to accept or return NumPy scalar values such as
``np.float32`` or ``np.float64``, and to handle single precision and double
precision separately on the C++ side. By default, however, both are bound to
Python's builtin double-precision ``float``, so they cannot be told apart.
Previously this could only be worked around with the ``py::buffer`` trick,
which significantly reduces the readability of the code.
Luckily, there's a helper type for this occasion - ``py::numpy_scalar``:
.. code-block:: cpp
m.def("add", [](py::numpy_scalar<float> a, py::numpy_scalar<float> b) {
return py::make_scalar(a.value + b.value);
});
m.def("add", [](py::numpy_scalar<double> a, py::numpy_scalar<double> b) {
return py::make_scalar(a.value + b.value);
});
This type is a thin wrapper that is explicitly convertible to and from the
type it wraps; the wrapped value is also directly accessible through its
``value`` member. Currently supported scalar types are NumPy arithmetic types:
``bool_``, ``int8``, ``int16``, ``int32``, ``int64``, ``uint8``, ``uint16``,
``uint32``, ``uint64``, ``float32``, ``float64``, ``longdouble``,
``complex64``, ``complex128``, ``clongdouble``, all of them mapping to their
respective C++ counterparts.
.. note::
``py::numpy_scalar<T>`` strictly matches NumPy scalar types. For example,
``py::numpy_scalar<int64_t>`` will accept ``np.int64(123)``,
but **not** a regular Python ``int`` like ``123``.
.. note::
Native C types are mapped to NumPy types in a platform-specific way: for
instance, ``char`` may be mapped to either ``np.int8`` or ``np.uint8``,
and ``long`` may use 4 or 8 bytes depending on the platform. Unless you
specifically need this platform-dependent behavior, prefer the fixed-width
types from ``<cstdint>`` (for example ``std::int32_t``).
Vectorizing functions
=====================

View File

@@ -49,6 +49,9 @@ PYBIND11_WARNING_DISABLE_MSVC(4127)
class dtype; // Forward declaration
class array; // Forward declaration
template <typename>
struct numpy_scalar; // Forward declaration
PYBIND11_NAMESPACE_BEGIN(detail)
template <>
@@ -245,6 +248,21 @@ struct npy_api {
NPY_UINT64_
= platform_lookup<std::uint64_t, unsigned long, unsigned long long, unsigned int>(
NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
NPY_FLOAT32_ = platform_lookup<float, double, float, long double>(
NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
NPY_FLOAT64_ = platform_lookup<double, double, float, long double>(
NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
// The typenum arguments pair up positionally with the candidate types, so the
// complex aliases must be given the complex typenums (NPY_C*_). Passing the
// real-valued typenums here would make NPY_COMPLEX64_/NPY_COMPLEX128_ alias
// the float typenums instead of the complex ones.
NPY_COMPLEX64_
= platform_lookup<std::complex<float>,
                  std::complex<double>,
                  std::complex<float>,
                  std::complex<long double>>(NPY_CDOUBLE_, NPY_CFLOAT_, NPY_CLONGDOUBLE_),
NPY_COMPLEX128_
= platform_lookup<std::complex<double>,
                  std::complex<double>,
                  std::complex<float>,
                  std::complex<long double>>(NPY_CDOUBLE_, NPY_CFLOAT_, NPY_CLONGDOUBLE_),
NPY_CHAR_ = std::is_signed<char>::value ? NPY_BYTE_ : NPY_UBYTE_,
};
unsigned int PyArray_RUNTIME_VERSION_;
@@ -268,6 +286,7 @@ struct npy_api {
unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
PyObject *(*PyArray_DescrFromType_)(int);
PyObject *(*PyArray_TypeObjectFromType_)(int);
PyObject *(*PyArray_NewFromDescr_)(PyTypeObject *,
PyObject *,
int,
@@ -284,6 +303,8 @@ struct npy_api {
PyTypeObject *PyVoidArrType_Type_;
PyTypeObject *PyArrayDescr_Type_;
PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
PyObject *(*PyArray_Scalar_)(void *, PyObject *, PyObject *);
void (*PyArray_ScalarAsCtype_)(PyObject *, void *);
PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *);
int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
@@ -301,7 +322,10 @@ private:
API_PyArrayDescr_Type = 3,
API_PyVoidArrType_Type = 39,
API_PyArray_DescrFromType = 45,
API_PyArray_TypeObjectFromType = 46,
API_PyArray_DescrFromScalar = 57,
API_PyArray_Scalar = 60,
API_PyArray_ScalarAsCtype = 62,
API_PyArray_FromAny = 69,
API_PyArray_Resize = 80,
// CopyInto was slot 82 and 50 was effectively an alias. NumPy 2 removed 82.
@@ -336,7 +360,10 @@ private:
DECL_NPY_API(PyVoidArrType_Type);
DECL_NPY_API(PyArrayDescr_Type);
DECL_NPY_API(PyArray_DescrFromType);
DECL_NPY_API(PyArray_TypeObjectFromType);
DECL_NPY_API(PyArray_DescrFromScalar);
DECL_NPY_API(PyArray_Scalar);
DECL_NPY_API(PyArray_ScalarAsCtype);
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_Resize);
DECL_NPY_API(PyArray_CopyInto);
@@ -355,6 +382,83 @@ private:
}
};
// Type trait: `value` is true only for specializations of std::complex.
template <typename T>
struct is_complex : std::integral_constant<bool, false> {};
// Match std::complex<Elem> for any element type.
template <typename Elem>
struct is_complex<std::complex<Elem>> : std::integral_constant<bool, true> {};
// Compile-time NumPy type name (e.g. "numpy.int32") for a C++ scalar type.
// Only the partial specializations below provide a `name` member.
template <typename T, typename = void>
struct npy_format_descriptor_name;

// Integral types: bool is named "numpy.bool"; all others are named
// "numpy.int" or "numpy.uint" followed by the width of T in bits.
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
    static constexpr auto name
        = const_name<std::is_same<T, bool>::value>(
            const_name("numpy.bool"),
            const_name<std::is_signed<T>::value>("numpy.int", "numpy.uint")
                + const_name<8 * sizeof(T)>());
};

// Floating-point types: (const) float and double are named "numpy.float"
// followed by their width in bits; anything else is "numpy.longdouble".
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
    static constexpr auto name
        = const_name<(std::is_same<typename std::remove_const<T>::type, float>::value
                      || std::is_same<typename std::remove_const<T>::type, double>::value)>(
            const_name("numpy.float") + const_name<8 * sizeof(T)>(),
            const_name("numpy.longdouble"));
};

// Complex types: the suffix is the total width in bits (two components), so
// it is 16 times the size of the component type. Components other than
// (const) float and double are named "numpy.longcomplex".
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
    static constexpr auto name
        = const_name<(
            std::is_same<typename std::remove_const<typename T::value_type>::type, float>::value
            || std::is_same<typename std::remove_const<typename T::value_type>::type,
                            double>::value)>(
            const_name("numpy.complex") + const_name<16 * sizeof(typename T::value_type)>(),
            const_name("numpy.longcomplex"));
};
// Per-type metadata for NumPy scalars. The primary template is empty; only the
// explicit specializations generated by the macro below provide members.
template <typename T>
struct numpy_scalar_info {};
// Defines numpy_scalar_info<ctype_> with two members:
//   name    - the name taken from npy_format_descriptor_name<ctype_>
//   typenum - the npy_api enumerator spelled typenum_ plus a trailing underscore
// The macro body ends without a semicolon, so each use below supplies its own.
#define PYBIND11_NUMPY_SCALAR_IMPL(ctype_, typenum_) \
template <> \
struct numpy_scalar_info<ctype_> { \
static constexpr auto name = npy_format_descriptor_name<ctype_>::name; \
static constexpr int typenum = npy_api::typenum_##_; \
}
// boolean type
PYBIND11_NUMPY_SCALAR_IMPL(bool, NPY_BOOL);
// character types
// (char, signed char and unsigned char are three distinct C++ types, so each
// needs its own specialization)
PYBIND11_NUMPY_SCALAR_IMPL(char, NPY_CHAR);
PYBIND11_NUMPY_SCALAR_IMPL(signed char, NPY_BYTE);
PYBIND11_NUMPY_SCALAR_IMPL(unsigned char, NPY_UBYTE);
// signed integer types
PYBIND11_NUMPY_SCALAR_IMPL(std::int16_t, NPY_INT16);
PYBIND11_NUMPY_SCALAR_IMPL(std::int32_t, NPY_INT32);
PYBIND11_NUMPY_SCALAR_IMPL(std::int64_t, NPY_INT64);
// unsigned integer types
PYBIND11_NUMPY_SCALAR_IMPL(std::uint16_t, NPY_UINT16);
PYBIND11_NUMPY_SCALAR_IMPL(std::uint32_t, NPY_UINT32);
PYBIND11_NUMPY_SCALAR_IMPL(std::uint64_t, NPY_UINT64);
// floating point types
PYBIND11_NUMPY_SCALAR_IMPL(float, NPY_FLOAT);
PYBIND11_NUMPY_SCALAR_IMPL(double, NPY_DOUBLE);
PYBIND11_NUMPY_SCALAR_IMPL(long double, NPY_LONGDOUBLE);
// complex types
PYBIND11_NUMPY_SCALAR_IMPL(std::complex<float>, NPY_CFLOAT);
PYBIND11_NUMPY_SCALAR_IMPL(std::complex<double>, NPY_CDOUBLE);
PYBIND11_NUMPY_SCALAR_IMPL(std::complex<long double>, NPY_CLONGDOUBLE);
// The helper macro is only needed for the specializations above.
#undef PYBIND11_NUMPY_SCALAR_IMPL
// This table normalizes typenums by mapping NPY_INT_, NPY_LONG, ... to NPY_INT32_, NPY_INT64, ...
// This is needed to correctly handle situations where multiple typenums map to the same type,
// e.g. NPY_LONG_ may be equivalent to NPY_INT_ or NPY_LONGLONG_ despite having a different
@@ -453,10 +557,6 @@ template <typename T>
struct is_std_array : std::false_type {};
template <typename T, size_t N>
struct is_std_array<std::array<T, N>> : std::true_type {};
template <typename T>
struct is_complex : std::false_type {};
template <typename T>
struct is_complex<std::complex<T>> : std::true_type {};
template <typename T>
struct array_info_scalar {
@@ -670,8 +770,65 @@ template <typename T, ssize_t Dim>
struct type_caster<unchecked_mutable_reference<T, Dim>>
: type_caster<unchecked_reference<T, Dim>> {};
// Type caster for numpy_scalar<T>.
//
// Python -> C++: accepts only objects that are instances of the NumPy scalar
// type associated with T; the `convert` flag is ignored, so no implicit
// conversion from other Python objects is attempted.
// C++ -> Python: creates a NumPy scalar of the dtype associated with T.
template <typename T>
struct type_caster<numpy_scalar<T>> {
    using value_type = T;
    using type_info = numpy_scalar_info<T>;

    PYBIND11_TYPE_CASTER(numpy_scalar<T>, type_info::name);

    // NumPy scalar type object for T; looked up on first use and then cached.
    static handle &target_type() {
        static handle tp = npy_api::get().PyArray_TypeObjectFromType_(type_info::typenum);
        return tp;
    }

    // NumPy dtype descriptor for T; looked up on first use and then cached.
    static handle &target_dtype() {
        static handle tp = npy_api::get().PyArray_DescrFromType_(type_info::typenum);
        return tp;
    }

    // Returns true and fills `value` only if `src` is an instance of the
    // matching NumPy scalar type.
    bool load(handle src, bool) {
        // A null handle must not reach isinstance(); treat it as "no match".
        if (!src || !isinstance(src, target_type())) {
            return false;
        }
        npy_api::get().PyArray_ScalarAsCtype_(src.ptr(), &value.value);
        return true;
    }

    // Builds a NumPy scalar from the wrapped C++ value.
    static handle cast(numpy_scalar<T> src, return_value_policy, handle) {
        return npy_api::get().PyArray_Scalar_(&src.value, target_dtype().ptr(), nullptr);
    }
};
PYBIND11_NAMESPACE_END(detail)
// Thin wrapper around a C++ arithmetic value. Conversions to and from the
// wrapped type are explicit; the value is also accessible through `value`.
template <typename T>
struct numpy_scalar {
    using value_type = T;

    // Value-initialized so that a default-constructed wrapper never holds an
    // indeterminate value (reading one, e.g. in operator==, would be undefined).
    value_type value{};

    numpy_scalar() = default;
    // Parameter is deliberately not named `value` to avoid shadowing the member.
    explicit numpy_scalar(value_type v) : value(v) {}

    // Explicit conversion back to the wrapped type.
    explicit operator value_type() const { return value; }

    // Assigns a new wrapped value and returns *this.
    numpy_scalar &operator=(value_type v) {
        value = v;
        return *this;
    }

    // Two wrappers compare equal exactly when their wrapped values do.
    friend bool operator==(const numpy_scalar &a, const numpy_scalar &b) {
        return a.value == b.value;
    }

    friend bool operator!=(const numpy_scalar &a, const numpy_scalar &b) { return !(a == b); }
};
// Convenience factory: wraps `value` in a numpy_scalar<T>, with T deduced from
// the argument so callers do not have to spell out the template parameter.
template <typename T>
numpy_scalar<T> make_scalar(T value) {
    numpy_scalar<T> wrapped(value);
    return wrapped;
}
class dtype : public object {
public:
PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_)
@@ -1409,38 +1566,6 @@ struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::valu
}
};
template <typename T, typename = void>
struct npy_format_descriptor_name;
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_integral<T>::value>> {
static constexpr auto name = const_name<std::is_same<T, bool>::value>(
const_name("bool"),
const_name<std::is_signed<T>::value>("numpy.int", "numpy.uint")
+ const_name<sizeof(T) * 8>());
};
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<std::is_floating_point<T>::value>> {
static constexpr auto name = const_name < std::is_same<T, float>::value
|| std::is_same<T, const float>::value
|| std::is_same<T, double>::value
|| std::is_same<T, const double>::value
> (const_name("numpy.float") + const_name<sizeof(T) * 8>(),
const_name("numpy.longdouble"));
};
template <typename T>
struct npy_format_descriptor_name<T, enable_if_t<is_complex<T>::value>> {
static constexpr auto name = const_name < std::is_same<typename T::value_type, float>::value
|| std::is_same<typename T::value_type, const float>::value
|| std::is_same<typename T::value_type, double>::value
|| std::is_same<typename T::value_type, const double>::value
> (const_name("numpy.complex")
+ const_name<sizeof(typename T::value_type) * 16>(),
const_name("numpy.longcomplex"));
};
template <typename T>
struct npy_format_descriptor<
T,

View File

@@ -159,6 +159,7 @@ set(PYBIND11_TEST_FILES
test_native_enum
test_numpy_array
test_numpy_dtypes
test_numpy_scalars
test_numpy_vectorize
test_opaque_types
test_operator_overloading

View File

@@ -0,0 +1,63 @@
/*
tests/test_numpy_scalars.cpp -- strict NumPy scalars
Copyright (c) 2021 Steve R. Sun
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#include <pybind11/numpy.h>
#include "pybind11_tests.h"
#include <complex>
#include <cstdint>
namespace py = pybind11;
namespace pybind11_test_numpy_scalars {
// Function object that adds a fixed offset `x` to its argument. The sum is
// cast back to T because narrow integer operands are promoted before adding.
template <typename T>
struct add {
    T x;
    explicit add(T offset) : x(offset) {}
    T operator()(T rhs) const {
        const auto sum = x + rhs;
        return static_cast<T>(sum);
    }
};
// Registers a Python function named "test_<name>" that takes one strict NumPy
// scalar of type T (keyword argument "x"). The bound function returns a tuple
// of `name` and the result of applying `func` to the unwrapped value, wrapped
// again as a NumPy scalar of type T.
template <typename T, typename F>
void register_test(py::module &m, const char *name, F &&func) {
    const std::string py_name = std::string("test_") + name;
    m.def(
        py_name.c_str(),
        // Both `name` and `func` are captured by copy, so the binding does not
        // refer back to this function's parameters.
        [name, func](py::numpy_scalar<T> arg) {
            const T result = static_cast<T>(func(arg.value));
            return std::make_tuple(name, py::make_scalar(result));
        },
        py::arg("x"));
}
} // namespace pybind11_test_numpy_scalars
using namespace pybind11_test_numpy_scalars;
// Binds one "test_<dtype>" function per supported NumPy scalar type, plus
// equality and inequality checks on int32 scalars.
TEST_SUBMODULE(numpy_scalars, m) {
    using c64 = std::complex<float>;
    using c128 = std::complex<double>;
    using i32_scalar = py::numpy_scalar<int32_t>;

    // boolean: logical negation
    register_test<bool>(m, "bool", [](bool flag) { return !flag; });

    // signed integers: subtract the bit width
    register_test<int8_t>(m, "int8", add<int8_t>(-8));
    register_test<int16_t>(m, "int16", add<int16_t>(-16));
    register_test<int32_t>(m, "int32", add<int32_t>(-32));
    register_test<int64_t>(m, "int64", add<int64_t>(-64));

    // unsigned integers: add the bit width
    register_test<uint8_t>(m, "uint8", add<uint8_t>(8));
    register_test<uint16_t>(m, "uint16", add<uint16_t>(16));
    register_test<uint32_t>(m, "uint32", add<uint32_t>(32));
    register_test<uint64_t>(m, "uint64", add<uint64_t>(64));

    // floating point: add a small, exactly representable offset
    register_test<float>(m, "float32", add<float>(0.125f));
    register_test<double>(m, "float64", add<double>(0.25));

    // complex: offset only the imaginary part
    register_test<c64>(m, "complex64", add<c64>(c64(0.0f, -0.125f)));
    register_test<c128>(m, "complex128", add<c128>(c128(0.0, -0.25)));

    // comparison operators of numpy_scalar
    m.def("test_eq", [](i32_scalar lhs, i32_scalar rhs) { return lhs == rhs; });
    m.def("test_ne", [](i32_scalar lhs, i32_scalar rhs) { return lhs != rhs; });
}

View File

@@ -0,0 +1,54 @@
from __future__ import annotations
import pytest
from pybind11_tests import numpy_scalars as m
np = pytest.importorskip("numpy")
# Maps each NumPy scalar type under test to the value its ``test_<dtype>``
# binding is expected to return when called with that type's ``1``.
NPY_SCALAR_TYPES = {
np.bool_: False,
np.int8: -7,
np.int16: -15,
np.int32: -31,
np.int64: -63,
np.uint8: 9,
np.uint16: 17,
np.uint32: 33,
np.uint64: 65,
np.single: 1.125,
np.double: 1.25,
np.complex64: 1 - 0.125j,
np.complex128: 1 - 0.25j,
}
# Every type each binding is called with: all NumPy scalar types above plus a
# few builtin Python types, which the bindings are expected to reject.
ALL_SCALAR_TYPES = tuple(NPY_SCALAR_TYPES.keys()) + (int, bool, float, bytes, str)
@pytest.mark.parametrize(
    ("npy_scalar_type", "expected_value"), NPY_SCALAR_TYPES.items()
)
def test_numpy_scalars(npy_scalar_type, expected_value):
    """Check one ``test_<dtype>`` binding: its signature, its result for the
    matching NumPy scalar type, and its rejection of every other type."""
    # np.bool_ is registered as "bool", hence the stripped trailing underscore.
    dtype_name = npy_scalar_type.__name__.rstrip("_")
    bound_func = getattr(m, f"test_{dtype_name}")

    expected_doc = (
        f"test_{dtype_name}(x: numpy.{dtype_name}) -> tuple[str, numpy.{dtype_name}]\n"
    )
    assert bound_func.__doc__ == expected_doc

    for candidate in ALL_SCALAR_TYPES:
        arg = candidate(1)
        if candidate is not npy_scalar_type:
            # Strict matching: any other scalar type must be rejected.
            with pytest.raises(TypeError):
                bound_func(arg)
            continue
        returned_name, returned_value = bound_func(arg)
        assert returned_name == dtype_name
        assert isinstance(returned_value, npy_scalar_type)
        assert returned_value == candidate(expected_value)
def test_eq_ne():
    """Check the == and != bindings with equal and unequal int32 scalars."""
    three, also_three, five = np.int32(3), np.int32(3), np.int32(5)

    # equal values
    assert m.test_eq(three, also_three)
    assert not m.test_ne(three, also_three)

    # different values
    assert not m.test_eq(three, five)
    assert m.test_ne(three, five)