array: add unchecked access via proxy object

This adds bounds-unchecked access to arrays through a `a.unchecked<Type,
Dimensions>()` method.  (For `array_t<T>`, the `Type` template parameter
is omitted).  The mutable version (which requires the array have the
`writeable` flag) is available as `a.mutable_unchecked<...>()`.

Specifying the Dimensions as a template parameter allows storage of an
std::array; having the strides and sizes stored that way (as opposed to
storing a copy of the array's strides/shape pointers) allows the
compiler to make significant optimizations of the shape() method that it
can't make with a pointer; testing with nested loops of the form:

    for (size_t i0 = 0; i0 < r.shape(0); i0++)
        for (size_t i1 = 0; i1 < r.shape(1); i1++)
            ...
                r(i0, i1, ...) += 1;

over a 10 million element array gives around a 25% speedup (versus using
a pointer) for the 1D case, 33% for 2D, and runs more than twice as fast
with a 5D array.
This commit is contained in:
Jason Rhinelander
2017-03-19 01:14:23 -03:00
parent 0d765f4a7c
commit 423a49b8be
4 changed files with 217 additions and 7 deletions

View File

@@ -35,6 +35,9 @@
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
NAMESPACE_BEGIN(pybind11)
class array; // Forward declaration
NAMESPACE_BEGIN(detail)
template <typename type, typename SFINAE = void> struct npy_format_descriptor;
@@ -232,6 +235,78 @@ template <typename T> using is_pod_struct = all_of<
satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
>;
template <size_t Dim = 0, typename Strides> size_t byte_offset_unsafe(const Strides &) { return 0; }
template <size_t Dim = 0, typename Strides, typename... Ix>
size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) {
return i * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...);
}
/** Proxy class providing unsafe, unchecked const access to array data. This is constructed through
* the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`.
*/
template <typename T, size_t Dims>
class unchecked_reference {
protected:
const unsigned char *data_;
// Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
// make large performance gains on big, nested loops.
std::array<size_t, Dims> shape_, strides_;
friend class pybind11::array;
unchecked_reference(const void *data, const size_t *shape, const size_t *strides)
: data_{reinterpret_cast<const unsigned char *>(data)} {
for (size_t i = 0; i < Dims; i++) {
shape_[i] = shape[i];
strides_[i] = strides[i];
}
}
public:
/** Unchecked const reference access to data at the given indices. Omiting trailing indices
* is equivalent to specifying them as 0.
*/
template <typename... Ix> const T& operator()(Ix... index) const {
static_assert(sizeof...(Ix) <= Dims, "Invalid number of indices for unchecked array reference");
return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, size_t{index}...));
}
/** Unchecked const reference access to data; this operator only participates if the reference
* is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
*/
template <size_t D = Dims, typename = enable_if_t<D == 1>>
const T &operator[](size_t index) const { return operator()(index); }
/// Returns the shape (i.e. size) of dimension `dim`
size_t shape(size_t dim) const { return shape_[dim]; }
/// Returns the number of dimensions of the array
constexpr static size_t ndim() { return Dims; }
};
template <typename T, size_t Dims>
class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
friend class pybind11::array;
using ConstBase = unchecked_reference<T, Dims>;
using ConstBase::ConstBase;
public:
/// Mutable, unchecked access to data at the given indices.
template <typename... Ix> T& operator()(Ix... index) {
static_assert(sizeof...(Ix) == Dims, "Invalid number of indices for unchecked array reference");
return const_cast<T &>(ConstBase::operator()(index...));
}
/** Mutable, unchecked access data at the given index; this operator only participates if the
* reference is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
*/
template <size_t D = Dims, typename = enable_if_t<D == 1>>
T &operator[](size_t index) { return operator()(index); }
};
template <typename T, size_t Dim>
struct type_caster<unchecked_reference<T, Dim>> {
static_assert(Dim == (size_t) -1 /* always fail */, "unchecked array proxy object is not castable");
};
template <typename T, size_t Dim>
struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};
NAMESPACE_END(detail)
class dtype : public object {
@@ -500,6 +575,31 @@ public:
return offset_at(index...) / itemsize();
}
/** Returns a proxy object that provides access to the array's data without bounds or
* dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
* care: the array must not be destroyed or reshaped for the duration of the returned object,
* and the caller must take care not to access invalid dimensions or dimension indices.
*/
template <typename T, size_t Dims> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
if (ndim() != Dims)
throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
"; expected " + std::to_string(Dims));
return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides());
}
/** Returns a proxy object that provides const access to the array's data without bounds or
* dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the
* underlying array have the `writable` flag. Use with care: the array must not be destroyed or
* reshaped for the duration of the returned object, and the caller must take care not to access
* invalid dimensions or dimension indices.
*/
template <typename T, size_t Dims> detail::unchecked_reference<T, Dims> unchecked() const {
if (ndim() != Dims)
throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
"; expected " + std::to_string(Dims));
return detail::unchecked_reference<T, Dims>(data(), shape(), strides());
}
/// Return a new view with all of the dimensions of length 1 removed
array squeeze() {
auto& api = detail::npy_api::get();
@@ -525,15 +625,9 @@ protected:
template<typename... Ix> size_t byte_offset(Ix... index) const {
check_dimensions(index...);
return byte_offset_unsafe(index...);
return detail::byte_offset_unsafe(strides(), size_t{index}...);
}
template<size_t dim = 0, typename... Ix> size_t byte_offset_unsafe(size_t i, Ix... index) const {
return i * strides()[dim] + byte_offset_unsafe<dim + 1>(index...);
}
template<size_t dim = 0> size_t byte_offset_unsafe() const { return 0; }
void check_writeable() const {
if (!writeable())
throw std::domain_error("array is not writeable");
@@ -637,6 +731,25 @@ public:
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
}
/** Returns a proxy object that provides access to the array's data without bounds or
* dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
* care: the array must not be destroyed or reshaped for the duration of the returned object,
* and the caller must take care not to access invalid dimensions or dimension indices.
*/
template <size_t Dims> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
return array::mutable_unchecked<T, Dims>();
}
/** Returns a proxy object that provides const access to the array's data without bounds or
* dimensionality checking. Unlike `unchecked()`, this does not require that the underlying
* array have the `writable` flag. Use with care: the array must not be destroyed or reshaped
* for the duration of the returned object, and the caller must take care not to access invalid
* dimensions or dimension indices.
*/
template <size_t Dims> detail::unchecked_reference<T, Dims> unchecked() const {
return array::unchecked<T, Dims>();
}
/// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
/// it). In case of an error, nullptr is returned and the Python error is cleared.
static array_t ensure(handle h) {