From 2b4fbbd521a51f6908d3ffe6f9fd46ab03418560 Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Thu, 8 Apr 2021 22:56:46 -0700 Subject: [PATCH] Bug fix for virtual_overrider_self_life_support ASAN heap-use-after-free failure. (#2942) * Porting subset of absltest code from reproducer provided by @elkhrt. Baseline for debugging ASAN heap-use-after-free. * Moving Py_DECREF to resolve ASAN heap-use-after-free failure. * Fixing trivial formatting issue. * Workaround for clang 3.6 and 3.7. --- .../virtual_overrider_self_life_support.h | 4 +- tests/CMakeLists.txt | 1 + tests/test_class_sh_trampoline_unique_ptr.cpp | 48 +++++++++++++++++++ tests/test_class_sh_trampoline_unique_ptr.py | 17 +++++++ 4 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 tests/test_class_sh_trampoline_unique_ptr.cpp create mode 100644 tests/test_class_sh_trampoline_unique_ptr.py diff --git a/include/pybind11/virtual_overrider_self_life_support.h b/include/pybind11/virtual_overrider_self_life_support.h index b4c2437d0..bf82f69a6 100644 --- a/include/pybind11/virtual_overrider_self_life_support.h +++ b/include/pybind11/virtual_overrider_self_life_support.h @@ -26,10 +26,10 @@ struct virtual_overrider_self_life_support { void *value_void_ptr = loaded_v_h.value_ptr(); if (value_void_ptr != nullptr) { PyGILState_STATE threadstate = PyGILState_Ensure(); - Py_DECREF((PyObject *) loaded_v_h.inst); - loaded_v_h.value_ptr() = nullptr; + loaded_v_h.value_ptr() = nullptr; loaded_v_h.holder().release_disowned(); detail::deregister_instance(loaded_v_h.inst, value_void_ptr, loaded_v_h.type); + Py_DECREF((PyObject *) loaded_v_h.inst); // Must be after deregister. PyGILState_Release(threadstate); } } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index acc9efb17..bf7705051 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -107,6 +107,7 @@ set(PYBIND11_TEST_FILES test_class_sh_factory_constructors.cpp test_class_sh_inheritance.cpp test_class_sh_trampoline_shared_ptr_cpp_arg.cpp + test_class_sh_trampoline_unique_ptr.cpp test_class_sh_unique_ptr_member.cpp test_class_sh_virtual_py_cpp_mix.cpp test_class_sh_with_alias.cpp diff --git a/tests/test_class_sh_trampoline_unique_ptr.cpp b/tests/test_class_sh_trampoline_unique_ptr.cpp new file mode 100644 index 000000000..375f14657 --- /dev/null +++ b/tests/test_class_sh_trampoline_unique_ptr.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2021 The Pybind Development Team. +// All rights reserved. Use of this source code is governed by a +// BSD-style license that can be found in the LICENSE file. + +#include "pybind11/smart_holder.h" +#include "pybind11/virtual_overrider_self_life_support.h" +#include "pybind11_tests.h" + +namespace { + +class Class { +public: + virtual ~Class() = default; + virtual std::unique_ptr clone() const = 0; + virtual int foo() const = 0; + +protected: + Class() = default; + + // Some compilers complain about implicitly defined versions of some of the following: + Class(const Class &) = default; +}; + +} // namespace + +PYBIND11_SMART_HOLDER_TYPE_CASTERS(Class) + +namespace { + +class PyClass : public Class, public py::virtual_overrider_self_life_support { +public: + std::unique_ptr clone() const override { + PYBIND11_OVERRIDE_PURE(std::unique_ptr, Class, clone); + } + + int foo() const override { PYBIND11_OVERRIDE_PURE(int, Class, foo); } +}; + +} // namespace + +TEST_SUBMODULE(class_sh_trampoline_unique_ptr, m) { + py::classh(m, "Class") + .def(py::init<>()) + .def("clone", &Class::clone) + .def("foo", &Class::foo); + + m.def("clone_and_foo", [](const Class &obj) { return obj.clone()->foo(); }); +} diff --git a/tests/test_class_sh_trampoline_unique_ptr.py b/tests/test_class_sh_trampoline_unique_ptr.py new file mode 100644 index 000000000..d43cfeb4f --- /dev/null +++ b/tests/test_class_sh_trampoline_unique_ptr.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +import pybind11_tests.class_sh_trampoline_unique_ptr as m + + +class MyClass(m.Class): + def foo(self): + return 10 + + def clone(self): + return MyClass() + + +def test_py_clone_and_foo(): + obj = MyClass() + assert obj.foo() == 10 + assert m.clone_and_foo(obj) == 10