mirror of
https://github.com/pybind/pybind11.git
synced 2026-05-13 17:56:02 +00:00
fix: avoid copy constructor instantiation in shared_ptr fallback cast (#6028)
* tests: add regressions for shared_ptr reference_internal fallback * fix: avoid copy constructor instantiation in shared_ptr fallback cast * Remove stray empty line * tests: rename PyTorch shared_ptr regression test files * refactor: add cast_non_owning helper for reference-like casts Name the non-owning generic cast path so callers do not have to rediscover that reference-like policies must pass null copy/move constructor callbacks. This keeps the shared_ptr reference_internal fallback self-documenting and points future maintainers toward the safe API. Made-with: Cursor * tests: guard deprecated-copy warning probes with __has_warning Use __has_warning for the Clang-only regression test so older compiler jobs skip unsupported warning groups instead of failing with -Wunknown-warning-option. A simple __clang_major__ >= 13 guard would be shorter, but it bakes in a version cutoff; __has_warning is slightly more verbose while being more robust to vendor builds, backports, and future packaging differences. Made-with: Cursor --------- Co-authored-by: Ralf W. Grosse-Kunstleve <rgrossekunst@nvidia.com>
This commit is contained in:
committed by
Ralf W. Grosse-Kunstleve
parent
ad5bc9e80e
commit
ab392bd845
@@ -1026,7 +1026,7 @@ public:
|
||||
}
|
||||
|
||||
if (parent) {
|
||||
return type_caster_base<type>::cast(
|
||||
return type_caster_generic::cast_non_owning(
|
||||
srcs, return_value_policy::reference_internal, parent);
|
||||
}
|
||||
|
||||
|
||||
@@ -1004,6 +1004,18 @@ public:
|
||||
return cast(srcs, policy, parent, copy_constructor, move_constructor, existing_holder);
|
||||
}
|
||||
|
||||
static handle cast_non_owning(const cast_sources &srcs,
|
||||
return_value_policy policy,
|
||||
handle parent,
|
||||
const void *existing_holder = nullptr) {
|
||||
// Reference-like policies alias an existing C++ object instead of creating
|
||||
// a new one, so copy/move constructor callbacks must remain null here.
|
||||
assert(policy == return_value_policy::reference
|
||||
|| policy == return_value_policy::reference_internal
|
||||
|| policy == return_value_policy::automatic_reference);
|
||||
return cast(srcs, policy, parent, nullptr, nullptr, existing_holder);
|
||||
}
|
||||
|
||||
PYBIND11_NOINLINE static handle cast(const cast_sources &srcs,
|
||||
return_value_policy policy,
|
||||
handle parent,
|
||||
|
||||
@@ -167,6 +167,7 @@ set(PYBIND11_TEST_FILES
|
||||
test_operator_overloading
|
||||
test_pickling
|
||||
test_potentially_slicing_weak_ptr
|
||||
test_pytorch_shared_ptr_cast_regression
|
||||
test_python_multiple_inheritance
|
||||
test_pytypes
|
||||
test_scoped_critical_section
|
||||
|
||||
@@ -204,3 +204,15 @@ def test_non_smart_holder_member_type_with_smart_holder_owner_aliases_member():
|
||||
legacy = obj.legacy
|
||||
legacy.value = 13
|
||||
assert obj.legacy.value == 13
|
||||
|
||||
|
||||
def test_non_smart_holder_member_type_with_smart_holder_owner_aliases_member_multiple_reads():
|
||||
obj = m.ShWithSimpleStructMember()
|
||||
|
||||
a = obj.legacy
|
||||
b = obj.legacy
|
||||
|
||||
a.value = 13
|
||||
|
||||
assert b.value == 13
|
||||
assert obj.legacy.value == 13
|
||||
|
||||
62
tests/test_pytorch_shared_ptr_cast_regression.cpp
Normal file
62
tests/test_pytorch_shared_ptr_cast_regression.cpp
Normal file
@@ -0,0 +1,62 @@
|
||||
#include "pybind11_tests.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#if defined(__clang__)
|
||||
# if __has_warning("-Wdeprecated-copy-with-user-provided-dtor")
|
||||
# pragma clang diagnostic error "-Wdeprecated-copy-with-user-provided-dtor"
|
||||
# endif
|
||||
# if __has_warning("-Wdeprecated-copy-with-dtor")
|
||||
# pragma clang diagnostic error "-Wdeprecated-copy-with-dtor"
|
||||
# endif
|
||||
#endif
|
||||
|
||||
namespace test_pytorch_regressions {
|
||||
|
||||
// Directly extracted from PyTorch patterns that regressed in CI.
|
||||
struct TracingState : std::enable_shared_from_this<TracingState> {
|
||||
TracingState() = default;
|
||||
~TracingState() = default;
|
||||
int value = 0;
|
||||
};
|
||||
|
||||
const std::shared_ptr<TracingState> &get_tracing_state() {
|
||||
static std::shared_ptr<TracingState> state = std::make_shared<TracingState>();
|
||||
return state;
|
||||
}
|
||||
|
||||
struct InterfaceType {
|
||||
~InterfaceType() = default;
|
||||
int value = 0;
|
||||
};
|
||||
using InterfaceTypePtr = std::shared_ptr<InterfaceType>;
|
||||
|
||||
struct CompilationUnit {
|
||||
InterfaceTypePtr iface = std::make_shared<InterfaceType>();
|
||||
|
||||
InterfaceTypePtr get_interface(const std::string &) const { return iface; }
|
||||
};
|
||||
|
||||
} // namespace test_pytorch_regressions
|
||||
|
||||
TEST_SUBMODULE(pybind11_pytorch_regressions, m) {
|
||||
using namespace test_pytorch_regressions;
|
||||
|
||||
py::class_<TracingState, std::shared_ptr<TracingState>>(m, "TracingState")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("value", &TracingState::value);
|
||||
|
||||
m.def("_get_tracing_state", []() { return get_tracing_state(); });
|
||||
|
||||
py::class_<InterfaceType, InterfaceTypePtr>(m, "InterfaceType")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("value", &InterfaceType::value);
|
||||
|
||||
py::class_<CompilationUnit, std::shared_ptr<CompilationUnit>>(m, "CompilationUnit")
|
||||
.def(py::init<>())
|
||||
.def("get_interface",
|
||||
[](const std::shared_ptr<CompilationUnit> &self, const std::string &name) {
|
||||
return self->get_interface(name);
|
||||
});
|
||||
}
|
||||
25
tests/test_pytorch_shared_ptr_cast_regression.py
Normal file
25
tests/test_pytorch_shared_ptr_cast_regression.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pybind11_tests import pybind11_pytorch_regressions as m
|
||||
|
||||
|
||||
def test_pytorch_like_get_tracing_state_aliases_singleton_shared_ptr():
|
||||
a = m._get_tracing_state()
|
||||
b = m._get_tracing_state()
|
||||
|
||||
a.value = 17
|
||||
|
||||
assert b.value == 17
|
||||
assert m._get_tracing_state().value == 17
|
||||
|
||||
|
||||
def test_pytorch_like_compilation_unit_get_interface_aliases_member_shared_ptr():
|
||||
cu = m.CompilationUnit()
|
||||
|
||||
a = cu.get_interface("iface")
|
||||
b = cu.get_interface("iface")
|
||||
|
||||
a.value = 23
|
||||
|
||||
assert b.value == 23
|
||||
assert cu.get_interface("iface").value == 23
|
||||
Reference in New Issue
Block a user