mirror of
https://github.com/pybind/pybind11.git
synced 2026-04-20 14:59:27 +00:00
feat: numpy scalars (#5726)
This commit is contained in:
@@ -159,6 +159,7 @@ set(PYBIND11_TEST_FILES
|
||||
test_native_enum
|
||||
test_numpy_array
|
||||
test_numpy_dtypes
|
||||
test_numpy_scalars
|
||||
test_numpy_vectorize
|
||||
test_opaque_types
|
||||
test_operator_overloading
|
||||
|
||||
63
tests/test_numpy_scalars.cpp
Normal file
63
tests/test_numpy_scalars.cpp
Normal file
@@ -0,0 +1,63 @@
|
||||
/*
|
||||
tests/test_numpy_scalars.cpp -- strict NumPy scalars
|
||||
|
||||
Copyright (c) 2021 Steve R. Sun
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#include <pybind11/numpy.h>
|
||||
|
||||
#include "pybind11_tests.h"
|
||||
|
||||
#include <complex>
|
||||
#include <cstdint>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace pybind11_test_numpy_scalars {
|
||||
|
||||
template <typename T>
|
||||
struct add {
|
||||
T x;
|
||||
explicit add(T x) : x(x) {}
|
||||
T operator()(T y) const { return static_cast<T>(x + y); }
|
||||
};
|
||||
|
||||
template <typename T, typename F>
|
||||
void register_test(py::module &m, const char *name, F &&func) {
|
||||
m.def((std::string("test_") + name).c_str(),
|
||||
[=](py::numpy_scalar<T> v) {
|
||||
return std::make_tuple(name, py::make_scalar(static_cast<T>(func(v.value))));
|
||||
},
|
||||
py::arg("x"));
|
||||
}
|
||||
|
||||
} // namespace pybind11_test_numpy_scalars
|
||||
|
||||
using namespace pybind11_test_numpy_scalars;
|
||||
|
||||
TEST_SUBMODULE(numpy_scalars, m) {
|
||||
using cfloat = std::complex<float>;
|
||||
using cdouble = std::complex<double>;
|
||||
|
||||
register_test<bool>(m, "bool", [](bool x) { return !x; });
|
||||
register_test<int8_t>(m, "int8", add<int8_t>(-8));
|
||||
register_test<int16_t>(m, "int16", add<int16_t>(-16));
|
||||
register_test<int32_t>(m, "int32", add<int32_t>(-32));
|
||||
register_test<int64_t>(m, "int64", add<int64_t>(-64));
|
||||
register_test<uint8_t>(m, "uint8", add<uint8_t>(8));
|
||||
register_test<uint16_t>(m, "uint16", add<uint16_t>(16));
|
||||
register_test<uint32_t>(m, "uint32", add<uint32_t>(32));
|
||||
register_test<uint64_t>(m, "uint64", add<uint64_t>(64));
|
||||
register_test<float>(m, "float32", add<float>(0.125f));
|
||||
register_test<double>(m, "float64", add<double>(0.25f));
|
||||
register_test<cfloat>(m, "complex64", add<cfloat>({0, -0.125f}));
|
||||
register_test<cdouble>(m, "complex128", add<cdouble>({0, -0.25f}));
|
||||
|
||||
m.def("test_eq",
|
||||
[](py::numpy_scalar<int32_t> a, py::numpy_scalar<int32_t> b) { return a == b; });
|
||||
m.def("test_ne",
|
||||
[](py::numpy_scalar<int32_t> a, py::numpy_scalar<int32_t> b) { return a != b; });
|
||||
}
|
||||
54
tests/test_numpy_scalars.py
Normal file
54
tests/test_numpy_scalars.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from pybind11_tests import numpy_scalars as m
|
||||
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
NPY_SCALAR_TYPES = {
|
||||
np.bool_: False,
|
||||
np.int8: -7,
|
||||
np.int16: -15,
|
||||
np.int32: -31,
|
||||
np.int64: -63,
|
||||
np.uint8: 9,
|
||||
np.uint16: 17,
|
||||
np.uint32: 33,
|
||||
np.uint64: 65,
|
||||
np.single: 1.125,
|
||||
np.double: 1.25,
|
||||
np.complex64: 1 - 0.125j,
|
||||
np.complex128: 1 - 0.25j,
|
||||
}
|
||||
|
||||
ALL_SCALAR_TYPES = tuple(NPY_SCALAR_TYPES.keys()) + (int, bool, float, bytes, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("npy_scalar_type", "expected_value"), NPY_SCALAR_TYPES.items()
|
||||
)
|
||||
def test_numpy_scalars(npy_scalar_type, expected_value):
|
||||
tpnm = npy_scalar_type.__name__.rstrip("_")
|
||||
test_tpnm = getattr(m, "test_" + tpnm)
|
||||
assert (
|
||||
test_tpnm.__doc__
|
||||
== f"test_{tpnm}(x: numpy.{tpnm}) -> tuple[str, numpy.{tpnm}]\n"
|
||||
)
|
||||
for tp in ALL_SCALAR_TYPES:
|
||||
value = tp(1)
|
||||
if tp is npy_scalar_type:
|
||||
result_tpnm, result_value = test_tpnm(value)
|
||||
assert result_tpnm == tpnm
|
||||
assert isinstance(result_value, npy_scalar_type)
|
||||
assert result_value == tp(expected_value)
|
||||
else:
|
||||
with pytest.raises(TypeError):
|
||||
test_tpnm(value)
|
||||
|
||||
|
||||
def test_eq_ne():
|
||||
assert m.test_eq(np.int32(3), np.int32(3))
|
||||
assert not m.test_eq(np.int32(3), np.int32(5))
|
||||
assert not m.test_ne(np.int32(3), np.int32(3))
|
||||
assert m.test_ne(np.int32(3), np.int32(5))
|
||||
Reference in New Issue
Block a user