diff --git a/include/pybind11/detail/smart_holder_type_casters.h b/include/pybind11/detail/smart_holder_type_casters.h index 3133aba31..78e015c93 100644 --- a/include/pybind11/detail/smart_holder_type_casters.h +++ b/include/pybind11/detail/smart_holder_type_casters.h @@ -18,6 +18,7 @@ #include "type_caster_base.h" #include "typeid.h" +#include #include #include #include @@ -343,6 +344,18 @@ struct smart_holder_type_caster_class_hooks : smart_holder_type_caster_base_tag } }; +struct shared_ptr_trampoline_self_life_support { + PyObject *self; + explicit shared_ptr_trampoline_self_life_support(instance *inst) + : self{reinterpret_cast(inst)} { + Py_INCREF(self); + } + void operator()(void *) { + gil_scoped_acquire gil; + Py_DECREF(self); + } +}; + template struct smart_holder_type_caster_load { using holder_type = pybindit::memory::smart_holder; @@ -376,14 +389,6 @@ struct smart_holder_type_caster_load { return *raw_ptr; } - struct shared_ptr_dec_ref_deleter { - PyObject *self; - void operator()(void *) { - gil_scoped_acquire gil; - Py_DECREF(self); - } - }; - std::shared_ptr loaded_as_shared_ptr() const { if (load_impl.unowned_void_ptr_from_direct_conversion != nullptr) throw cast_error("Unowned pointer from direct conversion cannot be converted to a" @@ -404,24 +409,26 @@ struct smart_holder_type_caster_load { std::shared_ptr released_ptr = vptr_gd_ptr->released_ptr.lock(); if (released_ptr) return std::shared_ptr(released_ptr, type_raw_ptr); - auto self = reinterpret_cast(load_impl.loaded_v_h.inst); - Py_INCREF(self); - std::shared_ptr to_be_released(type_raw_ptr, shared_ptr_dec_ref_deleter{self}); + std::shared_ptr to_be_released( + type_raw_ptr, + shared_ptr_trampoline_self_life_support(load_impl.loaded_v_h.inst)); vptr_gd_ptr->released_ptr = to_be_released; return to_be_released; } - if (std::get_deleter(hld.vptr) != nullptr) { + auto sptsls_ptr = std::get_deleter(hld.vptr); + if (sptsls_ptr != nullptr) { // This code is reachable only if there are multiple registered_instances for the // same pointee. - // SMART_HOLDER_WIP: keep weak_ref - std::shared_ptr void_shd_ptr = hld.template as_shared_ptr(); - return std::shared_ptr(void_shd_ptr, type_raw_ptr); + assert(reinterpret_cast(load_impl.loaded_v_h.inst) + != sptsls_ptr->self); + return std::shared_ptr( + type_raw_ptr, + shared_ptr_trampoline_self_life_support(load_impl.loaded_v_h.inst)); } if (!pybindit::memory::type_has_shared_from_this(type_raw_ptr)) { - // SMART_HOLDER_WIP: keep weak_ref - auto self = reinterpret_cast(load_impl.loaded_v_h.inst); - Py_INCREF(self); - return std::shared_ptr(type_raw_ptr, shared_ptr_dec_ref_deleter{self}); + return std::shared_ptr( + type_raw_ptr, + shared_ptr_trampoline_self_life_support(load_impl.loaded_v_h.inst)); } if (hld.vptr_is_external_shared_ptr) { pybind11_fail("smart_holder_type_casters loaded_as_shared_ptr failure: not " diff --git a/tests/test_class_sh_trampoline_shared_from_this.py b/tests/test_class_sh_trampoline_shared_from_this.py index 3f2b86f7d..a80a3ce4f 100644 --- a/tests/test_class_sh_trampoline_shared_from_this.py +++ b/tests/test_class_sh_trampoline_shared_from_this.py @@ -170,6 +170,20 @@ def test_multiple_registered_instances_for_same_pointee(): assert obj_pt.attachment_in_dict == "Obj0" else: assert not hasattr(obj_pt, "attachment_in_dict") + assert obj0.history == "PySft" + break # Comment out for manual leak checking (use `top` command). + + +def test_multiple_registered_instances_for_same_pointee_leak(): + obj0 = PySft("") + while True: + stash1 = m.SftSharedPtrStash(1) + stash1.Add(m.Sft(obj0)) + assert stash1.use_count(0) == 1 + stash1.Add(m.Sft(obj0)) + assert stash1.use_count(0) == 1 + assert stash1.use_count(1) == 1 + assert obj0.history == "" break # Comment out for manual leak checking (use `top` command). diff --git a/tests/test_class_sh_trampoline_shared_ptr_cpp_arg.py b/tests/test_class_sh_trampoline_shared_ptr_cpp_arg.py index a3d74e8a6..ec72800f3 100644 --- a/tests/test_class_sh_trampoline_shared_ptr_cpp_arg.py +++ b/tests/test_class_sh_trampoline_shared_ptr_cpp_arg.py @@ -130,7 +130,7 @@ def test_infinite(): tester = m.SpBaseTester() while True: tester.set_object(m.SpBase()) - return # Comment out for manual leak checking (use `top` command). + break # Comment out for manual leak checking (use `top` command). def test_std_make_shared_factory(): @@ -139,4 +139,6 @@ def test_std_make_shared_factory(): super(PyChild, self).__init__(0) obj = PyChild() - assert m.pass_through_shd_ptr(obj) is obj + while True: + assert m.pass_through_shd_ptr(obj) is obj + break # Comment out for manual leak checking (use `top` command).