mirror of
https://github.com/pybind/pybind11.git
synced 2026-05-13 01:36:21 +00:00
feat: vectorize functions with void return type (#1969)
* Allow function/functor passed to py::vectorize to return void * Stealing @sizmailov's test and fixing unused argument warning * Add missing std::move() RVO doesn't work here because function return type is different from actual returned type * remove extra EOL * docs: add a few details * chore: pre-commit autoupdate * Remove array_iterator, array_begin, and array_end (in detail namespace) Co-authored-by: Sergei Izmailov <sergei.a.izmailov@gmail.com> Co-authored-by: Henry Schreiner <henryschreineriii@gmail.com>
This commit is contained in:
@@ -1274,19 +1274,6 @@ private:
|
||||
|
||||
#endif // __CLION_IDE__
|
||||
|
||||
template <class T>
|
||||
using array_iterator = typename std::add_pointer<T>::type;
|
||||
|
||||
template <class T>
|
||||
array_iterator<T> array_begin(const buffer_info& buffer) {
|
||||
return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
array_iterator<T> array_end(const buffer_info& buffer) {
|
||||
return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr) + buffer.size);
|
||||
}
|
||||
|
||||
class common_iterator {
|
||||
public:
|
||||
using container_type = std::vector<ssize_t>;
|
||||
@@ -1486,6 +1473,56 @@ struct vectorize_arg {
|
||||
using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>;
|
||||
};
|
||||
|
||||
|
||||
// py::vectorize when a return type is present
|
||||
template <typename Func, typename Return, typename... Args>
|
||||
struct vectorize_returned_array {
|
||||
using Type = array_t<Return>;
|
||||
|
||||
static Type create(broadcast_trivial trivial, const std::vector<ssize_t> &shape) {
|
||||
if (trivial == broadcast_trivial::f_trivial)
|
||||
return array_t<Return, array::f_style>(shape);
|
||||
else
|
||||
return array_t<Return>(shape);
|
||||
}
|
||||
|
||||
static Return *mutable_data(Type &array) {
|
||||
return array.mutable_data();
|
||||
}
|
||||
|
||||
static Return call(Func &f, Args &... args) {
|
||||
return f(args...);
|
||||
}
|
||||
|
||||
static void call(Return *out, size_t i, Func &f, Args &... args) {
|
||||
out[i] = f(args...);
|
||||
}
|
||||
};
|
||||
|
||||
// py::vectorize when a return type is not present
|
||||
template <typename Func, typename... Args>
|
||||
struct vectorize_returned_array<Func, void, Args...> {
|
||||
using Type = none;
|
||||
|
||||
static Type create(broadcast_trivial, const std::vector<ssize_t> &) {
|
||||
return none();
|
||||
}
|
||||
|
||||
static void *mutable_data(Type &) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static detail::void_type call(Func &f, Args &... args) {
|
||||
f(args...);
|
||||
return {};
|
||||
}
|
||||
|
||||
static void call(void *, size_t, Func &f, Args &... args) {
|
||||
f(args...);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Func, typename Return, typename... Args>
|
||||
struct vectorize_helper {
|
||||
|
||||
@@ -1520,6 +1557,8 @@ private:
|
||||
using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
|
||||
template <size_t Index> using param_n_t = typename std::tuple_element<Index, arg_call_types>::type;
|
||||
|
||||
using returned_array = vectorize_returned_array<Func, Return, Args...>;
|
||||
|
||||
// Runs a vectorized function given arguments tuple and three index sequences:
|
||||
// - Index is the full set of 0 ... (N-1) argument indices;
|
||||
// - VIndex is the subset of argument indices with vectorized parameters, letting us access
|
||||
@@ -1551,20 +1590,19 @@ private:
|
||||
// not wrapped in an array).
|
||||
if (size == 1 && ndim == 0) {
|
||||
PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr);
|
||||
return cast(f(*reinterpret_cast<param_n_t<Index> *>(params[Index])...));
|
||||
return cast(returned_array::call(f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...));
|
||||
}
|
||||
|
||||
array_t<Return> result;
|
||||
if (trivial == broadcast_trivial::f_trivial) result = array_t<Return, array::f_style>(shape);
|
||||
else result = array_t<Return>(shape);
|
||||
auto result = returned_array::create(trivial, shape);
|
||||
|
||||
if (size == 0) return std::move(result);
|
||||
|
||||
/* Call the function */
|
||||
auto mutable_data = returned_array::mutable_data(result);
|
||||
if (trivial == broadcast_trivial::non_trivial)
|
||||
apply_broadcast(buffers, params, result, i_seq, vi_seq, bi_seq);
|
||||
apply_broadcast(buffers, params, mutable_data, size, shape, i_seq, vi_seq, bi_seq);
|
||||
else
|
||||
apply_trivial(buffers, params, result.mutable_data(), size, i_seq, vi_seq, bi_seq);
|
||||
apply_trivial(buffers, params, mutable_data, size, i_seq, vi_seq, bi_seq);
|
||||
|
||||
return std::move(result);
|
||||
}
|
||||
@@ -1587,7 +1625,7 @@ private:
|
||||
}};
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
out[i] = f(*reinterpret_cast<param_n_t<Index> *>(params[Index])...);
|
||||
returned_array::call(out, i, f, *reinterpret_cast<param_n_t<Index> *>(params[Index])...);
|
||||
for (auto &x : vecparams) x.first += x.second;
|
||||
}
|
||||
}
|
||||
@@ -1595,19 +1633,18 @@ private:
|
||||
template <size_t... Index, size_t... VIndex, size_t... BIndex>
|
||||
void apply_broadcast(std::array<buffer_info, NVectorized> &buffers,
|
||||
std::array<void *, N> ¶ms,
|
||||
array_t<Return> &output_array,
|
||||
Return *out,
|
||||
size_t size,
|
||||
const std::vector<ssize_t> &output_shape,
|
||||
index_sequence<Index...>, index_sequence<VIndex...>, index_sequence<BIndex...>) {
|
||||
|
||||
buffer_info output = output_array.request();
|
||||
multi_array_iterator<NVectorized> input_iter(buffers, output.shape);
|
||||
multi_array_iterator<NVectorized> input_iter(buffers, output_shape);
|
||||
|
||||
for (array_iterator<Return> iter = array_begin<Return>(output), end = array_end<Return>(output);
|
||||
iter != end;
|
||||
++iter, ++input_iter) {
|
||||
for (size_t i = 0; i < size; ++i, ++input_iter) {
|
||||
PYBIND11_EXPAND_SIDE_EFFECTS((
|
||||
params[VIndex] = input_iter.template data<BIndex>()
|
||||
));
|
||||
*iter = f(*reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...);
|
||||
returned_array::call(out, i, f, *reinterpret_cast<param_n_t<Index> *>(std::get<Index>(params))...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user