mirror of
https://github.com/pybind/pybind11.git
synced 2026-05-12 01:10:34 +00:00
array: add unchecked access via proxy object
This adds bounds-unchecked access to arrays through a `a.unchecked<Type,
Dimensions>()` method. (For `array_t<T>`, the `Type` template parameter
is omitted). The mutable version (which requires the array have the
`writeable` flag) is available as `a.mutable_unchecked<...>()`.
Specifying the Dimensions as a template parameter allows storage of an
std::array; having the strides and sizes stored that way (as opposed to
storing a copy of the array's strides/shape pointers) allows the
compiler to make significant optimizations of the shape() method that it
can't make with a pointer; testing with nested loops of the form:
for (size_t i0 = 0; i0 < r.shape(0); i0++)
for (size_t i1 = 0; i1 < r.shape(1); i1++)
...
r(i0, i1, ...) += 1;
over a 10 million element array gives around a 25% speedup (versus using
a pointer) for the 1D case, 33% for 2D, and runs more than twice as fast
with a 5D array.
This commit is contained in:
@@ -35,6 +35,9 @@
|
||||
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
|
||||
|
||||
NAMESPACE_BEGIN(pybind11)
|
||||
|
||||
class array; // Forward declaration
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
template <typename type, typename SFINAE = void> struct npy_format_descriptor;
|
||||
|
||||
@@ -232,6 +235,78 @@ template <typename T> using is_pod_struct = all_of<
|
||||
satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
|
||||
>;
|
||||
|
||||
template <size_t Dim = 0, typename Strides> size_t byte_offset_unsafe(const Strides &) { return 0; }
|
||||
template <size_t Dim = 0, typename Strides, typename... Ix>
|
||||
size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) {
|
||||
return i * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...);
|
||||
}
|
||||
|
||||
/** Proxy class providing unsafe, unchecked const access to array data. This is constructed through
|
||||
* the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`.
|
||||
*/
|
||||
template <typename T, size_t Dims>
|
||||
class unchecked_reference {
|
||||
protected:
|
||||
const unsigned char *data_;
|
||||
// Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
|
||||
// make large performance gains on big, nested loops.
|
||||
std::array<size_t, Dims> shape_, strides_;
|
||||
|
||||
friend class pybind11::array;
|
||||
unchecked_reference(const void *data, const size_t *shape, const size_t *strides)
|
||||
: data_{reinterpret_cast<const unsigned char *>(data)} {
|
||||
for (size_t i = 0; i < Dims; i++) {
|
||||
shape_[i] = shape[i];
|
||||
strides_[i] = strides[i];
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/** Unchecked const reference access to data at the given indices. Omiting trailing indices
|
||||
* is equivalent to specifying them as 0.
|
||||
*/
|
||||
template <typename... Ix> const T& operator()(Ix... index) const {
|
||||
static_assert(sizeof...(Ix) <= Dims, "Invalid number of indices for unchecked array reference");
|
||||
return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, size_t{index}...));
|
||||
}
|
||||
/** Unchecked const reference access to data; this operator only participates if the reference
|
||||
* is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
|
||||
*/
|
||||
template <size_t D = Dims, typename = enable_if_t<D == 1>>
|
||||
const T &operator[](size_t index) const { return operator()(index); }
|
||||
|
||||
/// Returns the shape (i.e. size) of dimension `dim`
|
||||
size_t shape(size_t dim) const { return shape_[dim]; }
|
||||
|
||||
/// Returns the number of dimensions of the array
|
||||
constexpr static size_t ndim() { return Dims; }
|
||||
};
|
||||
|
||||
template <typename T, size_t Dims>
|
||||
class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
|
||||
friend class pybind11::array;
|
||||
using ConstBase = unchecked_reference<T, Dims>;
|
||||
using ConstBase::ConstBase;
|
||||
public:
|
||||
/// Mutable, unchecked access to data at the given indices.
|
||||
template <typename... Ix> T& operator()(Ix... index) {
|
||||
static_assert(sizeof...(Ix) == Dims, "Invalid number of indices for unchecked array reference");
|
||||
return const_cast<T &>(ConstBase::operator()(index...));
|
||||
}
|
||||
/** Mutable, unchecked access data at the given index; this operator only participates if the
|
||||
* reference is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
|
||||
*/
|
||||
template <size_t D = Dims, typename = enable_if_t<D == 1>>
|
||||
T &operator[](size_t index) { return operator()(index); }
|
||||
};
|
||||
|
||||
template <typename T, size_t Dim>
|
||||
struct type_caster<unchecked_reference<T, Dim>> {
|
||||
static_assert(Dim == (size_t) -1 /* always fail */, "unchecked array proxy object is not castable");
|
||||
};
|
||||
template <typename T, size_t Dim>
|
||||
struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
class dtype : public object {
|
||||
@@ -500,6 +575,31 @@ public:
|
||||
return offset_at(index...) / itemsize();
|
||||
}
|
||||
|
||||
/** Returns a proxy object that provides access to the array's data without bounds or
|
||||
* dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
|
||||
* care: the array must not be destroyed or reshaped for the duration of the returned object,
|
||||
* and the caller must take care not to access invalid dimensions or dimension indices.
|
||||
*/
|
||||
template <typename T, size_t Dims> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
|
||||
if (ndim() != Dims)
|
||||
throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
|
||||
"; expected " + std::to_string(Dims));
|
||||
return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides());
|
||||
}
|
||||
|
||||
/** Returns a proxy object that provides const access to the array's data without bounds or
|
||||
* dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the
|
||||
* underlying array have the `writable` flag. Use with care: the array must not be destroyed or
|
||||
* reshaped for the duration of the returned object, and the caller must take care not to access
|
||||
* invalid dimensions or dimension indices.
|
||||
*/
|
||||
template <typename T, size_t Dims> detail::unchecked_reference<T, Dims> unchecked() const {
|
||||
if (ndim() != Dims)
|
||||
throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
|
||||
"; expected " + std::to_string(Dims));
|
||||
return detail::unchecked_reference<T, Dims>(data(), shape(), strides());
|
||||
}
|
||||
|
||||
/// Return a new view with all of the dimensions of length 1 removed
|
||||
array squeeze() {
|
||||
auto& api = detail::npy_api::get();
|
||||
@@ -525,15 +625,9 @@ protected:
|
||||
|
||||
template<typename... Ix> size_t byte_offset(Ix... index) const {
|
||||
check_dimensions(index...);
|
||||
return byte_offset_unsafe(index...);
|
||||
return detail::byte_offset_unsafe(strides(), size_t{index}...);
|
||||
}
|
||||
|
||||
template<size_t dim = 0, typename... Ix> size_t byte_offset_unsafe(size_t i, Ix... index) const {
|
||||
return i * strides()[dim] + byte_offset_unsafe<dim + 1>(index...);
|
||||
}
|
||||
|
||||
template<size_t dim = 0> size_t byte_offset_unsafe() const { return 0; }
|
||||
|
||||
void check_writeable() const {
|
||||
if (!writeable())
|
||||
throw std::domain_error("array is not writeable");
|
||||
@@ -637,6 +731,25 @@ public:
|
||||
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
|
||||
}
|
||||
|
||||
/** Returns a proxy object that provides access to the array's data without bounds or
|
||||
* dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
|
||||
* care: the array must not be destroyed or reshaped for the duration of the returned object,
|
||||
* and the caller must take care not to access invalid dimensions or dimension indices.
|
||||
*/
|
||||
template <size_t Dims> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
|
||||
return array::mutable_unchecked<T, Dims>();
|
||||
}
|
||||
|
||||
/** Returns a proxy object that provides const access to the array's data without bounds or
|
||||
* dimensionality checking. Unlike `unchecked()`, this does not require that the underlying
|
||||
* array have the `writable` flag. Use with care: the array must not be destroyed or reshaped
|
||||
* for the duration of the returned object, and the caller must take care not to access invalid
|
||||
* dimensions or dimension indices.
|
||||
*/
|
||||
template <size_t Dims> detail::unchecked_reference<T, Dims> unchecked() const {
|
||||
return array::unchecked<T, Dims>();
|
||||
}
|
||||
|
||||
/// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
|
||||
/// it). In case of an error, nullptr is returned and the Python error is cleared.
|
||||
static array_t ensure(handle h) {
|
||||
|
||||
Reference in New Issue
Block a user