mirror of
https://github.com/pybind/pybind11.git
synced 2026-04-26 01:39:16 +00:00
Add helpers to array that return the size and strides as a std::span (#5974)
* Add helper functions to pybind11::array to return the shape and strides as a std::span. These functions are hidden with macros unless PYBIND11_CPP20 is defined and the <span> include has been found. * style: pre-commit fixes * tests: Add unit tests for shape_span() and strides_span() Add comprehensive unit tests for the new std::span helper functions: - Test 0D, 1D, 2D, and 3D arrays - Verify spans match regular shape()/strides() methods - Test that spans can be used to construct new arrays - Tests are conditionally compiled only when PYBIND11_HAS_SPAN is defined * Use __cpp_lib_span feature test macro instead of __has_include Replace __has_include(<span>) check with __cpp_lib_span feature test macro to resolve ambiguity where some pre-C++20 systems might have a global header called <span> that isn't the C++20 std::span. The check is moved after <version> is included, consistent with how __cpp_lib_char8_t is handled. Co-authored-by: Cursor <cursoragent@cursor.com> * Fix: Use py::ssize_t instead of ssize_t in span tests On Windows/MSVC, ssize_t is not available in the standard namespace without proper includes. Use py::ssize_t (the pybind11 typedef) instead to ensure cross-platform compatibility. Fixes compilation errors on: - Windows/MSVC 2022 (C++20) - GCC 10 (C++20) Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ralf W. Grosse-Kunstleve <rgrossekunst@nvidia.com> Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -252,6 +252,10 @@
|
|||||||
# define PYBIND11_HAS_U8STRING 1
|
# define PYBIND11_HAS_U8STRING 1
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(PYBIND11_CPP20) && defined(__cpp_lib_span) && __cpp_lib_span >= 202002L
|
||||||
|
# define PYBIND11_HAS_SPAN 1
|
||||||
|
#endif
|
||||||
|
|
||||||
// See description of PR #4246:
|
// See description of PR #4246:
|
||||||
#if !defined(PYBIND11_NO_ASSERT_GIL_HELD_INCREF_DECREF) && !defined(NDEBUG) \
|
#if !defined(PYBIND11_NO_ASSERT_GIL_HELD_INCREF_DECREF) && !defined(NDEBUG) \
|
||||||
&& !defined(PYPY_VERSION) && !defined(PYBIND11_ASSERT_GIL_HELD_INCREF_DECREF)
|
&& !defined(PYPY_VERSION) && !defined(PYBIND11_ASSERT_GIL_HELD_INCREF_DECREF)
|
||||||
|
|||||||
@@ -29,6 +29,10 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#ifdef PYBIND11_HAS_SPAN
|
||||||
|
# include <span>
|
||||||
|
#endif
|
||||||
|
|
||||||
#if defined(PYBIND11_NUMPY_1_ONLY)
|
#if defined(PYBIND11_NUMPY_1_ONLY)
|
||||||
# error "PYBIND11_NUMPY_1_ONLY is no longer supported (see PR #5595)."
|
# error "PYBIND11_NUMPY_1_ONLY is no longer supported (see PR #5595)."
|
||||||
#endif
|
#endif
|
||||||
@@ -1143,6 +1147,13 @@ public:
|
|||||||
/// Dimensions of the array
|
/// Dimensions of the array
|
||||||
const ssize_t *shape() const { return detail::array_proxy(m_ptr)->dimensions; }
|
const ssize_t *shape() const { return detail::array_proxy(m_ptr)->dimensions; }
|
||||||
|
|
||||||
|
#ifdef PYBIND11_HAS_SPAN
|
||||||
|
/// Dimensions of the array as a span
|
||||||
|
std::span<const ssize_t, std::dynamic_extent> shape_span() const {
|
||||||
|
return std::span(shape(), static_cast<std::size_t>(ndim()));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
/// Dimension along a given axis
|
/// Dimension along a given axis
|
||||||
ssize_t shape(ssize_t dim) const {
|
ssize_t shape(ssize_t dim) const {
|
||||||
if (dim >= ndim()) {
|
if (dim >= ndim()) {
|
||||||
@@ -1154,6 +1165,13 @@ public:
|
|||||||
/// Strides of the array
|
/// Strides of the array
|
||||||
const ssize_t *strides() const { return detail::array_proxy(m_ptr)->strides; }
|
const ssize_t *strides() const { return detail::array_proxy(m_ptr)->strides; }
|
||||||
|
|
||||||
|
#ifdef PYBIND11_HAS_SPAN
|
||||||
|
/// Strides of the array as a span
|
||||||
|
std::span<const ssize_t, std::dynamic_extent> strides_span() const {
|
||||||
|
return std::span(strides(), static_cast<std::size_t>(ndim()));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
/// Stride along a given axis
|
/// Stride along a given axis
|
||||||
ssize_t strides(ssize_t dim) const {
|
ssize_t strides(ssize_t dim) const {
|
||||||
if (dim >= ndim()) {
|
if (dim >= ndim()) {
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
// Size / dtype checks.
|
// Size / dtype checks.
|
||||||
struct DtypeCheck {
|
struct DtypeCheck {
|
||||||
@@ -246,6 +247,22 @@ TEST_SUBMODULE(numpy_array, sm) {
|
|||||||
sm.def("nbytes", [](const arr &a) { return a.nbytes(); });
|
sm.def("nbytes", [](const arr &a) { return a.nbytes(); });
|
||||||
sm.def("owndata", [](const arr &a) { return a.owndata(); });
|
sm.def("owndata", [](const arr &a) { return a.owndata(); });
|
||||||
|
|
||||||
|
#ifdef PYBIND11_HAS_SPAN
|
||||||
|
// test_shape_strides_span
|
||||||
|
sm.def("shape_span", [](const arr &a) {
|
||||||
|
auto span = a.shape_span();
|
||||||
|
return std::vector<py::ssize_t>(span.begin(), span.end());
|
||||||
|
});
|
||||||
|
sm.def("strides_span", [](const arr &a) {
|
||||||
|
auto span = a.strides_span();
|
||||||
|
return std::vector<py::ssize_t>(span.begin(), span.end());
|
||||||
|
});
|
||||||
|
// Test that spans can be used to construct new arrays
|
||||||
|
sm.def("array_from_spans", [](const arr &a) {
|
||||||
|
return py::array(a.dtype(), a.shape_span(), a.strides_span(), a.data(), a);
|
||||||
|
});
|
||||||
|
#endif
|
||||||
|
|
||||||
// test_index_offset
|
// test_index_offset
|
||||||
def_index_fn(index_at, const arr &);
|
def_index_fn(index_at, const arr &);
|
||||||
def_index_fn(index_at_t, const arr_t &);
|
def_index_fn(index_at_t, const arr_t &);
|
||||||
|
|||||||
@@ -68,6 +68,45 @@ def test_array_attributes():
|
|||||||
assert not m.owndata(a)
|
assert not m.owndata(a)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not hasattr(m, "shape_span"), reason="std::span not available")
|
||||||
|
def test_shape_strides_span():
|
||||||
|
# Test 0-dimensional array (scalar)
|
||||||
|
a = np.array(42, "f8")
|
||||||
|
assert m.ndim(a) == 0
|
||||||
|
assert m.shape_span(a) == []
|
||||||
|
assert m.strides_span(a) == []
|
||||||
|
|
||||||
|
# Test 1-dimensional array
|
||||||
|
a = np.array([1, 2, 3, 4], "u2")
|
||||||
|
assert m.ndim(a) == 1
|
||||||
|
assert m.shape_span(a) == [4]
|
||||||
|
assert m.strides_span(a) == [2]
|
||||||
|
|
||||||
|
# Test 2-dimensional array
|
||||||
|
a = np.array([[1, 2, 3], [4, 5, 6]], "u2").view()
|
||||||
|
a.flags.writeable = False
|
||||||
|
assert m.ndim(a) == 2
|
||||||
|
assert m.shape_span(a) == [2, 3]
|
||||||
|
assert m.strides_span(a) == [6, 2]
|
||||||
|
|
||||||
|
# Test 3-dimensional array
|
||||||
|
a = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], "i4")
|
||||||
|
assert m.ndim(a) == 3
|
||||||
|
assert m.shape_span(a) == [2, 2, 2]
|
||||||
|
# Verify spans match regular shape/strides
|
||||||
|
assert list(m.shape_span(a)) == list(m.shape(a))
|
||||||
|
assert list(m.strides_span(a)) == list(m.strides(a))
|
||||||
|
|
||||||
|
# Test that spans can be used to construct new arrays
|
||||||
|
original = np.array([[1, 2, 3], [4, 5, 6]], "f4")
|
||||||
|
new_array = m.array_from_spans(original)
|
||||||
|
assert new_array.shape == original.shape
|
||||||
|
assert new_array.strides == original.strides
|
||||||
|
assert new_array.dtype == original.dtype
|
||||||
|
# Verify data is shared (since we pass the same data pointer)
|
||||||
|
np.testing.assert_array_equal(new_array, original)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("args", "ret"), [([], 0), ([0], 0), ([1], 3), ([0, 1], 1), ([1, 2], 5)]
|
("args", "ret"), [([], 0), ([0], 0), ([1], 3), ([0, 1], 1), ([1, 2], 5)]
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user