feat: numpy scalars (#5726)

This commit is contained in:
Henry Schreiner
2025-06-18 19:40:31 -04:00
committed by GitHub
parent c60c14991d
commit cf3d1a75a2
5 changed files with 319 additions and 36 deletions

View File

@@ -159,6 +159,7 @@ set(PYBIND11_TEST_FILES
test_native_enum
test_numpy_array
test_numpy_dtypes
test_numpy_scalars
test_numpy_vectorize
test_opaque_types
test_operator_overloading

View File

@@ -0,0 +1,63 @@
/*
tests/test_numpy_scalars.cpp -- strict NumPy scalars
Copyright (c) 2021 Steve R. Sun
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#include <pybind11/numpy.h>
#include "pybind11_tests.h"
#include <complex>
#include <cstdint>
namespace py = pybind11;
namespace pybind11_test_numpy_scalars {
template <typename T>
struct add {
T x;
explicit add(T x) : x(x) {}
T operator()(T y) const { return static_cast<T>(x + y); }
};
template <typename T, typename F>
void register_test(py::module &m, const char *name, F &&func) {
m.def((std::string("test_") + name).c_str(),
[=](py::numpy_scalar<T> v) {
return std::make_tuple(name, py::make_scalar(static_cast<T>(func(v.value))));
},
py::arg("x"));
}
} // namespace pybind11_test_numpy_scalars
using namespace pybind11_test_numpy_scalars;
TEST_SUBMODULE(numpy_scalars, m) {
using cfloat = std::complex<float>;
using cdouble = std::complex<double>;
register_test<bool>(m, "bool", [](bool x) { return !x; });
register_test<int8_t>(m, "int8", add<int8_t>(-8));
register_test<int16_t>(m, "int16", add<int16_t>(-16));
register_test<int32_t>(m, "int32", add<int32_t>(-32));
register_test<int64_t>(m, "int64", add<int64_t>(-64));
register_test<uint8_t>(m, "uint8", add<uint8_t>(8));
register_test<uint16_t>(m, "uint16", add<uint16_t>(16));
register_test<uint32_t>(m, "uint32", add<uint32_t>(32));
register_test<uint64_t>(m, "uint64", add<uint64_t>(64));
register_test<float>(m, "float32", add<float>(0.125f));
register_test<double>(m, "float64", add<double>(0.25f));
register_test<cfloat>(m, "complex64", add<cfloat>({0, -0.125f}));
register_test<cdouble>(m, "complex128", add<cdouble>({0, -0.25f}));
m.def("test_eq",
[](py::numpy_scalar<int32_t> a, py::numpy_scalar<int32_t> b) { return a == b; });
m.def("test_ne",
[](py::numpy_scalar<int32_t> a, py::numpy_scalar<int32_t> b) { return a != b; });
}

View File

@@ -0,0 +1,54 @@
from __future__ import annotations
import pytest
from pybind11_tests import numpy_scalars as m
np = pytest.importorskip("numpy")
NPY_SCALAR_TYPES = {
np.bool_: False,
np.int8: -7,
np.int16: -15,
np.int32: -31,
np.int64: -63,
np.uint8: 9,
np.uint16: 17,
np.uint32: 33,
np.uint64: 65,
np.single: 1.125,
np.double: 1.25,
np.complex64: 1 - 0.125j,
np.complex128: 1 - 0.25j,
}
ALL_SCALAR_TYPES = tuple(NPY_SCALAR_TYPES.keys()) + (int, bool, float, bytes, str)
@pytest.mark.parametrize(
("npy_scalar_type", "expected_value"), NPY_SCALAR_TYPES.items()
)
def test_numpy_scalars(npy_scalar_type, expected_value):
tpnm = npy_scalar_type.__name__.rstrip("_")
test_tpnm = getattr(m, "test_" + tpnm)
assert (
test_tpnm.__doc__
== f"test_{tpnm}(x: numpy.{tpnm}) -> tuple[str, numpy.{tpnm}]\n"
)
for tp in ALL_SCALAR_TYPES:
value = tp(1)
if tp is npy_scalar_type:
result_tpnm, result_value = test_tpnm(value)
assert result_tpnm == tpnm
assert isinstance(result_value, npy_scalar_type)
assert result_value == tp(expected_value)
else:
with pytest.raises(TypeError):
test_tpnm(value)
def test_eq_ne():
assert m.test_eq(np.int32(3), np.int32(3))
assert not m.test_eq(np.int32(3), np.int32(5))
assert not m.test_ne(np.int32(3), np.int32(3))
assert m.test_ne(np.int32(3), np.int32(5))