mirror of
https://github.com/pybind/pybind11.git
synced 2026-03-14 20:27:47 +00:00
Fix template trampoline overload lookup failure
Problem ======= The template trampoline pattern documented in PR #322 has a problem with virtual method overloads in intermediate classes in the inheritance chain between the trampoline class and the base class. For example, consider the following inheritance structure, where `B` is the actual class, `PyB<B>` is the trampoline class, and `PyA<B>` is an intermediate class adding A's methods into the trampoline: PyB<B> -> PyA<B> -> B -> A Suppose PyA<B> has a method `some_method()` with a PYBIND11_OVERLOAD in it to overload the virtual `A::some_method()`. If a Python class `C` is defined that inherits from the pybind11-registered `B` and tries to provide an overriding `some_method()`, the PYBIND11_OVERLOADs declared in PyA<B> fails to find this overloaded method, and thus never invoke it (or, if pure virtual and not overridden in PyB<B>, raises an exception). This happens because the base (internal) `PYBIND11_OVERLOAD_INT` macro simply calls `get_overload(this, name)`; `get_overload()` then uses the inferred type of `this` to do a type lookup in `registered_types_cpp`. This is where it fails: `this` will be a `PyA<B> *`, but `PyA<B>` is neither the base type (`B`) nor the trampoline type (`PyB<B>`). As a result, the overload fails and we get a failed overload lookup. The fix ======= The fix is relatively simple: we can cast `this` passed to `get_overload()` to a `const B *`, which lets get_overload look up the correct class. Since trampoline classes should be derived from `B` classes anyway, this cast should be perfectly safe. This does require adding the class name as an argument to the PYBIND11_OVERLOAD_INT macro, but leaves the public macro signatures unchanged.
This commit is contained in:
@@ -69,20 +69,24 @@ def test_inheriting_repeat():
|
||||
obj = VI_AR()
|
||||
assert obj.say_something(3) == "hihihi"
|
||||
assert obj.unlucky_number() == 99
|
||||
assert obj.say_everything() == "hi 99"
|
||||
|
||||
obj = VI_AT()
|
||||
assert obj.say_something(3) == "hihihi"
|
||||
assert obj.unlucky_number() == 999
|
||||
assert obj.say_everything() == "hi 999"
|
||||
|
||||
for obj in [B_Repeat(), B_Tpl()]:
|
||||
assert obj.say_something(3) == "B says hi 3 times"
|
||||
assert obj.unlucky_number() == 13
|
||||
assert obj.lucky_number() == 7.0
|
||||
assert obj.say_everything() == "B says hi 1 times 13"
|
||||
|
||||
for obj in [C_Repeat(), C_Tpl()]:
|
||||
assert obj.say_something(3) == "B says hi 3 times"
|
||||
assert obj.unlucky_number() == 4444
|
||||
assert obj.lucky_number() == 888.0
|
||||
assert obj.say_everything() == "B says hi 1 times 4444"
|
||||
|
||||
class VI_CR(C_Repeat):
|
||||
def lucky_number(self):
|
||||
@@ -92,6 +96,7 @@ def test_inheriting_repeat():
|
||||
assert obj.say_something(3) == "B says hi 3 times"
|
||||
assert obj.unlucky_number() == 4444
|
||||
assert obj.lucky_number() == 889.25
|
||||
assert obj.say_everything() == "B says hi 1 times 4444"
|
||||
|
||||
class VI_CT(C_Tpl):
|
||||
pass
|
||||
@@ -100,6 +105,7 @@ def test_inheriting_repeat():
|
||||
assert obj.say_something(3) == "B says hi 3 times"
|
||||
assert obj.unlucky_number() == 4444
|
||||
assert obj.lucky_number() == 888.0
|
||||
assert obj.say_everything() == "B says hi 1 times 4444"
|
||||
|
||||
class VI_CCR(VI_CR):
|
||||
def lucky_number(self):
|
||||
@@ -109,6 +115,7 @@ def test_inheriting_repeat():
|
||||
assert obj.say_something(3) == "B says hi 3 times"
|
||||
assert obj.unlucky_number() == 4444
|
||||
assert obj.lucky_number() == 8892.5
|
||||
assert obj.say_everything() == "B says hi 1 times 4444"
|
||||
|
||||
class VI_CCT(VI_CT):
|
||||
def lucky_number(self):
|
||||
@@ -118,6 +125,7 @@ def test_inheriting_repeat():
|
||||
assert obj.say_something(3) == "B says hi 3 times"
|
||||
assert obj.unlucky_number() == 4444
|
||||
assert obj.lucky_number() == 888000.0
|
||||
assert obj.say_everything() == "B says hi 1 times 4444"
|
||||
|
||||
class VI_DR(D_Repeat):
|
||||
def unlucky_number(self):
|
||||
@@ -130,11 +138,13 @@ def test_inheriting_repeat():
|
||||
assert obj.say_something(3) == "B says hi 3 times"
|
||||
assert obj.unlucky_number() == 4444
|
||||
assert obj.lucky_number() == 888.0
|
||||
assert obj.say_everything() == "B says hi 1 times 4444"
|
||||
|
||||
obj = VI_DR()
|
||||
assert obj.say_something(3) == "B says hi 3 times"
|
||||
assert obj.unlucky_number() == 123
|
||||
assert obj.lucky_number() == 42.0
|
||||
assert obj.say_everything() == "B says hi 1 times 123"
|
||||
|
||||
class VI_DT(D_Tpl):
|
||||
def say_something(self, times):
|
||||
@@ -150,6 +160,28 @@ def test_inheriting_repeat():
|
||||
assert obj.say_something(3) == "VI_DT says: quack quack quack"
|
||||
assert obj.unlucky_number() == 1234
|
||||
assert obj.lucky_number() == -4.25
|
||||
assert obj.say_everything() == "VI_DT says: quack 1234"
|
||||
|
||||
class VI_DT2(VI_DT):
|
||||
def say_something(self, times):
|
||||
return "VI_DT2: " + ('QUACK' * times)
|
||||
|
||||
def unlucky_number(self):
|
||||
return -3
|
||||
|
||||
class VI_BT(B_Tpl):
|
||||
def say_something(self, times):
|
||||
return "VI_BT" * times
|
||||
def unlucky_number(self):
|
||||
return -7
|
||||
def lucky_number(self):
|
||||
return -1.375
|
||||
|
||||
obj = VI_BT()
|
||||
assert obj.say_something(3) == "VI_BTVI_BTVI_BT"
|
||||
assert obj.unlucky_number() == -7
|
||||
assert obj.lucky_number() == -1.375
|
||||
assert obj.say_everything() == "VI_BT -7"
|
||||
|
||||
@pytest.mark.skipif(not hasattr(pybind11_tests, 'NCVirt'),
|
||||
reason="NCVirt test broken on ICPC")
|
||||
|
||||
Reference in New Issue
Block a user