mirror of
https://github.com/pybind/pybind11.git
synced 2026-05-25 07:14:50 +00:00
fix: Invert input/output context if a Python function is called from C++ (#6055)
When the Callable itself is an input (parameter) to a C++ function, its arguments are outputs (C++ passes them to the Python callback), and vice versa. Therefore, we must invert them if C++ calls a Python function, but keep them the same in the other direction.
This commit is contained in:
@@ -222,5 +222,10 @@ constexpr descr<N + 4, Ts...> return_descr(const descr<N, Ts...> &descr) {
|
||||
return const_name("@$") + descr + const_name("@!");
|
||||
}
|
||||
|
||||
template <size_t N, typename... Ts>
|
||||
constexpr descr<N + 4, Ts...> inv_descr(const descr<N, Ts...> &descr) {
|
||||
return const_name("@~") + descr + const_name("@!");
|
||||
}
|
||||
|
||||
PYBIND11_NAMESPACE_END(detail)
|
||||
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
||||
@@ -138,9 +138,8 @@ public:
|
||||
PYBIND11_TYPE_CASTER(
|
||||
type,
|
||||
const_name("collections.abc.Callable[[")
|
||||
+ ::pybind11::detail::concat(::pybind11::detail::arg_descr(make_caster<Args>::name)...)
|
||||
+ const_name("], ") + ::pybind11::detail::return_descr(make_caster<retval_type>::name)
|
||||
+ const_name("]"));
|
||||
+ ::pybind11::detail::concat(::pybind11::detail::inv_descr(make_caster<Args>::name)...)
|
||||
+ const_name("], ") + make_caster<retval_type>::name + const_name("]"));
|
||||
};
|
||||
|
||||
PYBIND11_NAMESPACE_END(detail)
|
||||
|
||||
@@ -185,7 +185,8 @@ inline std::string generate_function_signature(const char *type_caster_name_fiel
|
||||
signature += *++pc;
|
||||
} else if (c == '@') {
|
||||
// `@^ ... @!` and `@$ ... @!` are used to force arg/return value type (see
|
||||
// typing::Callable/detail::arg_descr/detail::return_descr)
|
||||
// typing::Callable/detail::arg_descr/detail::return_descr).
|
||||
// `@~ ... @!` inverts the current context (see detail::inv_descr).
|
||||
if (*(pc + 1) == '^') {
|
||||
is_return_value.emplace(false);
|
||||
++pc;
|
||||
@@ -196,6 +197,11 @@ inline std::string generate_function_signature(const char *type_caster_name_fiel
|
||||
++pc;
|
||||
continue;
|
||||
}
|
||||
if (*(pc + 1) == '~') {
|
||||
is_return_value.emplace(!is_return_value.top());
|
||||
++pc;
|
||||
continue;
|
||||
}
|
||||
if (*(pc + 1) == '!') {
|
||||
is_return_value.pop();
|
||||
++pc;
|
||||
|
||||
@@ -194,9 +194,8 @@ struct handle_type_name<typing::Callable<Return(Args...)>> {
|
||||
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
|
||||
static constexpr auto name
|
||||
= const_name("collections.abc.Callable[[")
|
||||
+ ::pybind11::detail::concat(::pybind11::detail::arg_descr(make_caster<Args>::name)...)
|
||||
+ const_name("], ") + ::pybind11::detail::return_descr(make_caster<retval_type>::name)
|
||||
+ const_name("]");
|
||||
+ ::pybind11::detail::concat(::pybind11::detail::inv_descr(make_caster<Args>::name)...)
|
||||
+ const_name("], ") + make_caster<retval_type>::name + const_name("]");
|
||||
};
|
||||
|
||||
template <typename Return>
|
||||
@@ -204,8 +203,7 @@ struct handle_type_name<typing::Callable<Return(ellipsis)>> {
|
||||
// PEP 484 specifies this syntax for defining only return types of callables
|
||||
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
|
||||
static constexpr auto name = const_name("collections.abc.Callable[..., ")
|
||||
+ ::pybind11::detail::return_descr(make_caster<retval_type>::name)
|
||||
+ const_name("]");
|
||||
+ make_caster<retval_type>::name + const_name("]");
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -140,7 +140,7 @@ def test_cpp_function_roundtrip():
|
||||
def test_function_signatures(doc):
|
||||
assert (
|
||||
doc(m.test_callback3)
|
||||
== "test_callback3(arg0: collections.abc.Callable[[typing.SupportsInt | typing.SupportsIndex], int]) -> str"
|
||||
== "test_callback3(arg0: collections.abc.Callable[[int], typing.SupportsInt | typing.SupportsIndex]) -> str"
|
||||
)
|
||||
assert (
|
||||
doc(m.test_callback4)
|
||||
|
||||
@@ -976,14 +976,14 @@ def test_iterator_annotations(doc):
|
||||
def test_fn_annotations(doc):
|
||||
assert (
|
||||
doc(m.annotate_fn)
|
||||
== "annotate_fn(arg0: collections.abc.Callable[[list[str], str], int]) -> None"
|
||||
== "annotate_fn(arg0: collections.abc.Callable[[list[str], str], typing.SupportsInt | typing.SupportsIndex]) -> None"
|
||||
)
|
||||
|
||||
|
||||
def test_fn_return_only(doc):
|
||||
assert (
|
||||
doc(m.annotate_fn_only_return)
|
||||
== "annotate_fn_only_return(arg0: collections.abc.Callable[..., int]) -> None"
|
||||
== "annotate_fn_only_return(arg0: collections.abc.Callable[..., typing.SupportsInt | typing.SupportsIndex]) -> None"
|
||||
)
|
||||
|
||||
|
||||
@@ -1085,7 +1085,7 @@ def test_literal(doc):
|
||||
)
|
||||
assert (
|
||||
doc(m.identity_literal_arrow_with_callable)
|
||||
== 'identity_literal_arrow_with_callable(arg0: collections.abc.Callable[[typing.Literal["->"], float | int], float]) -> collections.abc.Callable[[typing.Literal["->"], float | int], float]'
|
||||
== 'identity_literal_arrow_with_callable(arg0: collections.abc.Callable[[typing.Literal["->"], float], float | int]) -> collections.abc.Callable[[typing.Literal["->"], float | int], float]'
|
||||
)
|
||||
assert (
|
||||
doc(m.identity_literal_all_special_chars)
|
||||
@@ -1325,27 +1325,27 @@ def test_arg_return_type_hints(doc, backport_typehints):
|
||||
# Callable<R(A)> identity
|
||||
assert (
|
||||
doc(m.identity_callable)
|
||||
== "identity_callable(arg0: collections.abc.Callable[[float | int], float]) -> collections.abc.Callable[[float | int], float]"
|
||||
== "identity_callable(arg0: collections.abc.Callable[[float], float | int]) -> collections.abc.Callable[[float | int], float]"
|
||||
)
|
||||
# Callable<R(...)> identity
|
||||
assert (
|
||||
doc(m.identity_callable_ellipsis)
|
||||
== "identity_callable_ellipsis(arg0: collections.abc.Callable[..., float]) -> collections.abc.Callable[..., float]"
|
||||
== "identity_callable_ellipsis(arg0: collections.abc.Callable[..., float | int]) -> collections.abc.Callable[..., float]"
|
||||
)
|
||||
# Nested Callable<R(A)> identity
|
||||
assert (
|
||||
doc(m.identity_nested_callable)
|
||||
== "identity_nested_callable(arg0: collections.abc.Callable[[collections.abc.Callable[[float | int], float]], collections.abc.Callable[[float | int], float]]) -> collections.abc.Callable[[collections.abc.Callable[[float | int], float]], collections.abc.Callable[[float | int], float]]"
|
||||
== "identity_nested_callable(arg0: collections.abc.Callable[[collections.abc.Callable[[float | int], float]], collections.abc.Callable[[float], float | int]]) -> collections.abc.Callable[[collections.abc.Callable[[float], float | int]], collections.abc.Callable[[float | int], float]]"
|
||||
)
|
||||
# Callable<R(A)>
|
||||
assert (
|
||||
doc(m.apply_callable)
|
||||
== "apply_callable(arg0: float | int, arg1: collections.abc.Callable[[float | int], float]) -> float"
|
||||
== "apply_callable(arg0: float | int, arg1: collections.abc.Callable[[float], float | int]) -> float"
|
||||
)
|
||||
# Callable<R(...)>
|
||||
assert (
|
||||
doc(m.apply_callable_ellipsis)
|
||||
== "apply_callable_ellipsis(arg0: float | int, arg1: collections.abc.Callable[..., float]) -> float"
|
||||
== "apply_callable_ellipsis(arg0: float | int, arg1: collections.abc.Callable[..., float | int]) -> float"
|
||||
)
|
||||
# Union<T1, T2>
|
||||
assert (
|
||||
|
||||
Reference in New Issue
Block a user