mirror of
https://github.com/pybind/pybind11.git
synced 2026-03-14 20:27:47 +00:00
Add helpers to array that return the size and strides as a std::span (#5974)
* Add helper functions to pybind11::array to return the shape and strides as a std::span. These functions are hidden with macros unless PYBIND11_CPP20 is defined and the <span> include has been found. * style: pre-commit fixes * tests: Add unit tests for shape_span() and strides_span() Add comprehensive unit tests for the new std::span helper functions: - Test 0D, 1D, 2D, and 3D arrays - Verify spans match regular shape()/strides() methods - Test that spans can be used to construct new arrays - Tests are conditionally compiled only when PYBIND11_HAS_SPAN is defined * Use __cpp_lib_span feature test macro instead of __has_include Replace __has_include(<span>) check with __cpp_lib_span feature test macro to resolve ambiguity where some pre-C++20 systems might have a global header called <span> that isn't the C++20 std::span. The check is moved after <version> is included, consistent with how __cpp_lib_char8_t is handled. Co-authored-by: Cursor <cursoragent@cursor.com> * Fix: Use py::ssize_t instead of ssize_t in span tests On Windows/MSVC, ssize_t is not available in the standard namespace without proper includes. Use py::ssize_t (the pybind11 typedef) instead to ensure cross-platform compatibility. Fixes compilation errors on: - Windows/MSVC 2022 (C++20) - GCC 10 (C++20) Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ralf W. Grosse-Kunstleve <rgrossekunst@nvidia.com> Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -252,6 +252,10 @@
|
||||
# define PYBIND11_HAS_U8STRING 1
|
||||
#endif
|
||||
|
||||
#if defined(PYBIND11_CPP20) && defined(__cpp_lib_span) && __cpp_lib_span >= 202002L
|
||||
# define PYBIND11_HAS_SPAN 1
|
||||
#endif
|
||||
|
||||
// See description of PR #4246:
|
||||
#if !defined(PYBIND11_NO_ASSERT_GIL_HELD_INCREF_DECREF) && !defined(NDEBUG) \
|
||||
&& !defined(PYPY_VERSION) && !defined(PYBIND11_ASSERT_GIL_HELD_INCREF_DECREF)
|
||||
|
||||
@@ -29,6 +29,10 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#ifdef PYBIND11_HAS_SPAN
|
||||
# include <span>
|
||||
#endif
|
||||
|
||||
#if defined(PYBIND11_NUMPY_1_ONLY)
|
||||
# error "PYBIND11_NUMPY_1_ONLY is no longer supported (see PR #5595)."
|
||||
#endif
|
||||
@@ -1143,6 +1147,13 @@ public:
|
||||
/// Dimensions of the array
|
||||
const ssize_t *shape() const { return detail::array_proxy(m_ptr)->dimensions; }
|
||||
|
||||
#ifdef PYBIND11_HAS_SPAN
|
||||
/// Dimensions of the array as a span
|
||||
std::span<const ssize_t, std::dynamic_extent> shape_span() const {
|
||||
return std::span(shape(), static_cast<std::size_t>(ndim()));
|
||||
}
|
||||
#endif
|
||||
|
||||
/// Dimension along a given axis
|
||||
ssize_t shape(ssize_t dim) const {
|
||||
if (dim >= ndim()) {
|
||||
@@ -1154,6 +1165,13 @@ public:
|
||||
/// Strides of the array
|
||||
const ssize_t *strides() const { return detail::array_proxy(m_ptr)->strides; }
|
||||
|
||||
#ifdef PYBIND11_HAS_SPAN
|
||||
/// Strides of the array as a span
|
||||
std::span<const ssize_t, std::dynamic_extent> strides_span() const {
|
||||
return std::span(strides(), static_cast<std::size_t>(ndim()));
|
||||
}
|
||||
#endif
|
||||
|
||||
/// Stride along a given axis
|
||||
ssize_t strides(ssize_t dim) const {
|
||||
if (dim >= ndim()) {
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
// Size / dtype checks.
|
||||
struct DtypeCheck {
|
||||
@@ -246,6 +247,22 @@ TEST_SUBMODULE(numpy_array, sm) {
|
||||
sm.def("nbytes", [](const arr &a) { return a.nbytes(); });
|
||||
sm.def("owndata", [](const arr &a) { return a.owndata(); });
|
||||
|
||||
#ifdef PYBIND11_HAS_SPAN
|
||||
// test_shape_strides_span
|
||||
sm.def("shape_span", [](const arr &a) {
|
||||
auto span = a.shape_span();
|
||||
return std::vector<py::ssize_t>(span.begin(), span.end());
|
||||
});
|
||||
sm.def("strides_span", [](const arr &a) {
|
||||
auto span = a.strides_span();
|
||||
return std::vector<py::ssize_t>(span.begin(), span.end());
|
||||
});
|
||||
// Test that spans can be used to construct new arrays
|
||||
sm.def("array_from_spans", [](const arr &a) {
|
||||
return py::array(a.dtype(), a.shape_span(), a.strides_span(), a.data(), a);
|
||||
});
|
||||
#endif
|
||||
|
||||
// test_index_offset
|
||||
def_index_fn(index_at, const arr &);
|
||||
def_index_fn(index_at_t, const arr_t &);
|
||||
|
||||
@@ -68,6 +68,45 @@ def test_array_attributes():
|
||||
assert not m.owndata(a)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hasattr(m, "shape_span"), reason="std::span not available")
|
||||
def test_shape_strides_span():
|
||||
# Test 0-dimensional array (scalar)
|
||||
a = np.array(42, "f8")
|
||||
assert m.ndim(a) == 0
|
||||
assert m.shape_span(a) == []
|
||||
assert m.strides_span(a) == []
|
||||
|
||||
# Test 1-dimensional array
|
||||
a = np.array([1, 2, 3, 4], "u2")
|
||||
assert m.ndim(a) == 1
|
||||
assert m.shape_span(a) == [4]
|
||||
assert m.strides_span(a) == [2]
|
||||
|
||||
# Test 2-dimensional array
|
||||
a = np.array([[1, 2, 3], [4, 5, 6]], "u2").view()
|
||||
a.flags.writeable = False
|
||||
assert m.ndim(a) == 2
|
||||
assert m.shape_span(a) == [2, 3]
|
||||
assert m.strides_span(a) == [6, 2]
|
||||
|
||||
# Test 3-dimensional array
|
||||
a = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], "i4")
|
||||
assert m.ndim(a) == 3
|
||||
assert m.shape_span(a) == [2, 2, 2]
|
||||
# Verify spans match regular shape/strides
|
||||
assert list(m.shape_span(a)) == list(m.shape(a))
|
||||
assert list(m.strides_span(a)) == list(m.strides(a))
|
||||
|
||||
# Test that spans can be used to construct new arrays
|
||||
original = np.array([[1, 2, 3], [4, 5, 6]], "f4")
|
||||
new_array = m.array_from_spans(original)
|
||||
assert new_array.shape == original.shape
|
||||
assert new_array.strides == original.strides
|
||||
assert new_array.dtype == original.dtype
|
||||
# Verify data is shared (since we pass the same data pointer)
|
||||
np.testing.assert_array_equal(new_array, original)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("args", "ret"), [([], 0), ([0], 0), ([1], 3), ([0, 1], 1), ([1, 2], 5)]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user