Fix make_key_iterator/make_value_iterator for prvalue iterators (#3348)

* Add a test showing a flaw in make_key_iterator/make_value_iterator

If the iterator dereference operator returns a value rather than a
reference (and that pair also does not *contain* references),
make_key_iterator and make_value_iterator will return a reference to a
temporary, causing a segfault.

* Fix make_key_iterator/make_value_iterator for prvalue iterators

If an iterator returns a pair<T1, T2> rather than a reference to a pair
or a pair of references, make_key_iterator and make_value_iterator would
return a reference to a temporary, typically leading to a segfault. This
is because the value category of member access to a prvalue is an
xvalue, not a prvalue, so decltype produces an rvalue reference type.
Fix the type calculation to handle this case.

I also removed some decltype parentheses that weren't needed, either
because the expression isn't one of the special cases for decltype or
because decltype was only used for SFINAE. Hopefully that makes the code
a bit more readable.

Closes #3347

* Attempt a workaround for nvcc
This commit is contained in:
Bruce Merry
2021-10-11 17:35:39 +02:00
committed by GitHub
parent 750e38dcfd
commit 8a7c266d26
3 changed files with 69 additions and 12 deletions

View File

@@ -38,6 +38,17 @@ bool operator==(const NonZeroIterator<std::pair<A, B>>& it, const NonZeroSentine
return !(*it).first || !(*it).second;
}
/* Iterator where dereferencing returns prvalues instead of references. */
template<typename T>
class NonRefIterator {
const T* ptr_;
public:
explicit NonRefIterator(const T *ptr) : ptr_(ptr) {}
T operator*() const { return T(*ptr_); }
NonRefIterator& operator++() { ++ptr_; return *this; }
bool operator==(const NonRefIterator &other) const { return ptr_ == other.ptr_; }
};
class NonCopyableInt {
public:
explicit NonCopyableInt(int value) : value_(value) {}
@@ -331,7 +342,7 @@ TEST_SUBMODULE(sequences_and_iterators, m) {
py::class_<IntPairs>(m, "IntPairs")
.def(py::init<std::vector<std::pair<int, int>>>())
.def("nonzero", [](const IntPairs& s) {
return py::make_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()), NonZeroSentinel());
return py::make_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()), NonZeroSentinel());
}, py::keep_alive<0, 1>())
.def("nonzero_keys", [](const IntPairs& s) {
return py::make_key_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()), NonZeroSentinel());
@@ -340,6 +351,20 @@ TEST_SUBMODULE(sequences_and_iterators, m) {
return py::make_value_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()), NonZeroSentinel());
}, py::keep_alive<0, 1>())
// test iterator that returns values instead of references
.def("nonref", [](const IntPairs& s) {
return py::make_iterator(NonRefIterator<std::pair<int, int>>(s.begin()),
NonRefIterator<std::pair<int, int>>(s.end()));
}, py::keep_alive<0, 1>())
.def("nonref_keys", [](const IntPairs& s) {
return py::make_key_iterator(NonRefIterator<std::pair<int, int>>(s.begin()),
NonRefIterator<std::pair<int, int>>(s.end()));
}, py::keep_alive<0, 1>())
.def("nonref_values", [](const IntPairs& s) {
return py::make_value_iterator(NonRefIterator<std::pair<int, int>>(s.begin()),
NonRefIterator<std::pair<int, int>>(s.end()));
}, py::keep_alive<0, 1>())
// test single-argument make_iterator
.def("simple_iterator", [](IntPairs& self) {
return py::make_iterator(self);