Allow arbitrary class_ template option ordering

The current pybind11::class_<Type, Holder, Trampoline> fixed template
ordering results in a requirement to repeat the Holder with its default
value (std::unique_ptr<Type>) argument, which is a little bit annoying:
it needs to be specified not because we want to override the default,
but rather because we need to specify the third argument.

This commit removes this limitation by making the class_ template take
the type name plus a parameter pack of options.  It then extracts the
first valid holder type and the first subclass type for holder_type and
trampoline type_alias, respectively.  (If unfound, both fall back to
their current defaults, `std::unique_ptr<type>` and `type`,
respectively).  If any unmatched template arguments are provided, a
static assertion fails.

What this means is that you can specify or omit the arguments in any
order:

    py::class_<A, PyA> c1(m, "A");
    py::class_<B, PyB, std::shared_ptr<B>> c2(m, "B");
    py::class_<C, std::shared_ptr<C>, PyB> c3(m, "C");

It also allows future class attributes (such as base types in the next
commit) to be passed as class template types rather than needing to use
a py::base<> wrapper.
This commit is contained in:
Jason Rhinelander
2016-09-06 12:17:06 -04:00
parent a3dbdc67f5
commit 5fffe200e3
10 changed files with 187 additions and 52 deletions

View File

@@ -253,31 +253,31 @@ public:
void initialize_inherited_virtuals(py::module &m) {
// Method 1: repeat
py::class_<A_Repeat, std::unique_ptr<A_Repeat>, PyA_Repeat>(m, "A_Repeat")
py::class_<A_Repeat, PyA_Repeat>(m, "A_Repeat")
.def(py::init<>())
.def("unlucky_number", &A_Repeat::unlucky_number)
.def("say_something", &A_Repeat::say_something)
.def("say_everything", &A_Repeat::say_everything);
py::class_<B_Repeat, std::unique_ptr<B_Repeat>, PyB_Repeat>(m, "B_Repeat", py::base<A_Repeat>())
py::class_<B_Repeat, PyB_Repeat>(m, "B_Repeat", py::base<A_Repeat>())
.def(py::init<>())
.def("lucky_number", &B_Repeat::lucky_number);
py::class_<C_Repeat, std::unique_ptr<C_Repeat>, PyC_Repeat>(m, "C_Repeat", py::base<B_Repeat>())
py::class_<C_Repeat, PyC_Repeat>(m, "C_Repeat", py::base<B_Repeat>())
.def(py::init<>());
py::class_<D_Repeat, std::unique_ptr<D_Repeat>, PyD_Repeat>(m, "D_Repeat", py::base<C_Repeat>())
py::class_<D_Repeat, PyD_Repeat>(m, "D_Repeat", py::base<C_Repeat>())
.def(py::init<>());
// Method 2: Templated trampolines
py::class_<A_Tpl, std::unique_ptr<A_Tpl>, PyA_Tpl<>>(m, "A_Tpl")
py::class_<A_Tpl, PyA_Tpl<>>(m, "A_Tpl")
.def(py::init<>())
.def("unlucky_number", &A_Tpl::unlucky_number)
.def("say_something", &A_Tpl::say_something)
.def("say_everything", &A_Tpl::say_everything);
py::class_<B_Tpl, std::unique_ptr<B_Tpl>, PyB_Tpl<>>(m, "B_Tpl", py::base<A_Tpl>())
py::class_<B_Tpl, PyB_Tpl<>>(m, "B_Tpl", py::base<A_Tpl>())
.def(py::init<>())
.def("lucky_number", &B_Tpl::lucky_number);
py::class_<C_Tpl, std::unique_ptr<C_Tpl>, PyB_Tpl<C_Tpl>>(m, "C_Tpl", py::base<B_Tpl>())
py::class_<C_Tpl, PyB_Tpl<C_Tpl>>(m, "C_Tpl", py::base<B_Tpl>())
.def(py::init<>());
py::class_<D_Tpl, std::unique_ptr<D_Tpl>, PyB_Tpl<D_Tpl>>(m, "D_Tpl", py::base<C_Tpl>())
py::class_<D_Tpl, PyB_Tpl<D_Tpl>>(m, "D_Tpl", py::base<C_Tpl>())
.def(py::init<>());
};
@@ -287,7 +287,7 @@ test_initializer virtual_functions([](py::module &m) {
/* Important: indicate the trampoline class PyExampleVirt using the third
argument to py::class_. The second argument with the unique pointer
is simply the default holder type used by pybind11. */
py::class_<ExampleVirt, std::unique_ptr<ExampleVirt>, PyExampleVirt>(m, "ExampleVirt")
py::class_<ExampleVirt, PyExampleVirt>(m, "ExampleVirt")
.def(py::init<int>())
/* Reference original class in function definitions */
.def("run", &ExampleVirt::run)
@@ -301,7 +301,7 @@ test_initializer virtual_functions([](py::module &m) {
.def(py::init<int, int>());
#if !defined(__INTEL_COMPILER)
py::class_<NCVirt, std::unique_ptr<NCVirt>, NCVirtTrampoline>(m, "NCVirt")
py::class_<NCVirt, NCVirtTrampoline>(m, "NCVirt")
.def(py::init<>())
.def("get_noncopyable", &NCVirt::get_noncopyable)
.def("get_movable", &NCVirt::get_movable)