fix: Invert input/output context if a Python function is called from C++ (#6055)

When the Callable itself is an input (parameter) to a C++ function, its
arguments are outputs (C++ passes them to the Python callback), and vice
versa.
Therefore, we must invert them if C++ calls a Python function, but keep
them the same in the other direction.
This commit is contained in:
Matthias Klumpp
2026-05-16 20:28:05 +02:00
committed by GitHub
parent 00a9c6244c
commit 5a5f21deb6
6 changed files with 26 additions and 18 deletions

View File

@@ -222,5 +222,10 @@ constexpr descr<N + 4, Ts...> return_descr(const descr<N, Ts...> &descr) {
return const_name("@$") + descr + const_name("@!");
}
template <size_t N, typename... Ts>
constexpr descr<N + 4, Ts...> inv_descr(const descr<N, Ts...> &descr) {
return const_name("@~") + descr + const_name("@!");
}
PYBIND11_NAMESPACE_END(detail)
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -138,9 +138,8 @@ public:
PYBIND11_TYPE_CASTER(
type,
const_name("collections.abc.Callable[[")
+ ::pybind11::detail::concat(::pybind11::detail::arg_descr(make_caster<Args>::name)...)
+ const_name("], ") + ::pybind11::detail::return_descr(make_caster<retval_type>::name)
+ const_name("]"));
+ ::pybind11::detail::concat(::pybind11::detail::inv_descr(make_caster<Args>::name)...)
+ const_name("], ") + make_caster<retval_type>::name + const_name("]"));
};
PYBIND11_NAMESPACE_END(detail)

View File

@@ -185,7 +185,8 @@ inline std::string generate_function_signature(const char *type_caster_name_fiel
signature += *++pc;
} else if (c == '@') {
// `@^ ... @!` and `@$ ... @!` are used to force arg/return value type (see
// typing::Callable/detail::arg_descr/detail::return_descr)
// typing::Callable/detail::arg_descr/detail::return_descr).
// `@~ ... @!` inverts the current context (see detail::inv_descr).
if (*(pc + 1) == '^') {
is_return_value.emplace(false);
++pc;
@@ -196,6 +197,11 @@ inline std::string generate_function_signature(const char *type_caster_name_fiel
++pc;
continue;
}
if (*(pc + 1) == '~') {
is_return_value.emplace(!is_return_value.top());
++pc;
continue;
}
if (*(pc + 1) == '!') {
is_return_value.pop();
++pc;

View File

@@ -194,9 +194,8 @@ struct handle_type_name<typing::Callable<Return(Args...)>> {
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
static constexpr auto name
= const_name("collections.abc.Callable[[")
+ ::pybind11::detail::concat(::pybind11::detail::arg_descr(make_caster<Args>::name)...)
+ const_name("], ") + ::pybind11::detail::return_descr(make_caster<retval_type>::name)
+ const_name("]");
+ ::pybind11::detail::concat(::pybind11::detail::inv_descr(make_caster<Args>::name)...)
+ const_name("], ") + make_caster<retval_type>::name + const_name("]");
};
template <typename Return>
@@ -204,8 +203,7 @@ struct handle_type_name<typing::Callable<Return(ellipsis)>> {
// PEP 484 specifies this syntax for defining only return types of callables
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
static constexpr auto name = const_name("collections.abc.Callable[..., ")
+ ::pybind11::detail::return_descr(make_caster<retval_type>::name)
+ const_name("]");
+ make_caster<retval_type>::name + const_name("]");
};
template <typename T>

View File

@@ -140,7 +140,7 @@ def test_cpp_function_roundtrip():
def test_function_signatures(doc):
assert (
doc(m.test_callback3)
== "test_callback3(arg0: collections.abc.Callable[[typing.SupportsInt | typing.SupportsIndex], int]) -> str"
== "test_callback3(arg0: collections.abc.Callable[[int], typing.SupportsInt | typing.SupportsIndex]) -> str"
)
assert (
doc(m.test_callback4)

View File

@@ -976,14 +976,14 @@ def test_iterator_annotations(doc):
def test_fn_annotations(doc):
assert (
doc(m.annotate_fn)
== "annotate_fn(arg0: collections.abc.Callable[[list[str], str], int]) -> None"
== "annotate_fn(arg0: collections.abc.Callable[[list[str], str], typing.SupportsInt | typing.SupportsIndex]) -> None"
)
def test_fn_return_only(doc):
assert (
doc(m.annotate_fn_only_return)
== "annotate_fn_only_return(arg0: collections.abc.Callable[..., int]) -> None"
== "annotate_fn_only_return(arg0: collections.abc.Callable[..., typing.SupportsInt | typing.SupportsIndex]) -> None"
)
@@ -1085,7 +1085,7 @@ def test_literal(doc):
)
assert (
doc(m.identity_literal_arrow_with_callable)
== 'identity_literal_arrow_with_callable(arg0: collections.abc.Callable[[typing.Literal["->"], float | int], float]) -> collections.abc.Callable[[typing.Literal["->"], float | int], float]'
== 'identity_literal_arrow_with_callable(arg0: collections.abc.Callable[[typing.Literal["->"], float], float | int]) -> collections.abc.Callable[[typing.Literal["->"], float | int], float]'
)
assert (
doc(m.identity_literal_all_special_chars)
@@ -1325,27 +1325,27 @@ def test_arg_return_type_hints(doc, backport_typehints):
# Callable<R(A)> identity
assert (
doc(m.identity_callable)
== "identity_callable(arg0: collections.abc.Callable[[float | int], float]) -> collections.abc.Callable[[float | int], float]"
== "identity_callable(arg0: collections.abc.Callable[[float], float | int]) -> collections.abc.Callable[[float | int], float]"
)
# Callable<R(...)> identity
assert (
doc(m.identity_callable_ellipsis)
== "identity_callable_ellipsis(arg0: collections.abc.Callable[..., float]) -> collections.abc.Callable[..., float]"
== "identity_callable_ellipsis(arg0: collections.abc.Callable[..., float | int]) -> collections.abc.Callable[..., float]"
)
# Nested Callable<R(A)> identity
assert (
doc(m.identity_nested_callable)
== "identity_nested_callable(arg0: collections.abc.Callable[[collections.abc.Callable[[float | int], float]], collections.abc.Callable[[float | int], float]]) -> collections.abc.Callable[[collections.abc.Callable[[float | int], float]], collections.abc.Callable[[float | int], float]]"
== "identity_nested_callable(arg0: collections.abc.Callable[[collections.abc.Callable[[float | int], float]], collections.abc.Callable[[float], float | int]]) -> collections.abc.Callable[[collections.abc.Callable[[float], float | int]], collections.abc.Callable[[float | int], float]]"
)
# Callable<R(A)>
assert (
doc(m.apply_callable)
== "apply_callable(arg0: float | int, arg1: collections.abc.Callable[[float | int], float]) -> float"
== "apply_callable(arg0: float | int, arg1: collections.abc.Callable[[float], float | int]) -> float"
)
# Callable<R(...)>
assert (
doc(m.apply_callable_ellipsis)
== "apply_callable_ellipsis(arg0: float | int, arg1: collections.abc.Callable[..., float]) -> float"
== "apply_callable_ellipsis(arg0: float | int, arg1: collections.abc.Callable[..., float | int]) -> float"
)
# Union<T1, T2>
assert (