Add Context and Endpoint classes to enable non-Communicator use-cases (#166)

This PR implements and closes #137. The new `Endpoint` and `Context`
classes expose the connection establishing functionality from
`Communicator`, which now is only responsible for tying together the
bootstrapper with a context.

The largest breaking change here is that
`Communicator.connectOnSetup(...)` now returns the `Connection` wrapped
inside a `NonblockingFuture`. This is because with the way `Context` is
implemented a `Connection` is now fully initialized on construction.

Some smaller breaking API changes from this change are that
`RegisteredMemory` no longer has a `rank()` function (as there maybe no
concept of rank), and similarly `Connection` has no `remoteRank()` and
`tag()` functions. The latter are replaced by `remoteRankOf` and `tagOf`
functions in `Communicator`.

A new `EndpointConfig` class is introduced to avoid duplication of the
IB configuration parameters in the APIs of `Context` and `Communicator`.
The usual usage pattern of just passing in a `Transport` still works due
to an implicit conversion into `EndpointConfig`.

Miscellaneous changes:
-Cleans up how the PIMPL pattern is applied by making both the `Impl`
struct and the `pimpl_` pointers private for all relevant classes in the
core API.
-Enables ctest to be run from the build root directory.
This commit is contained in:
Olli Saarikivi
2023-09-05 22:10:04 -07:00
committed by GitHub
parent 858e381829
commit 828be48b21
25 changed files with 626 additions and 327 deletions

View File

@@ -105,7 +105,6 @@ void register_core(nb::module_& m) {
.def(nb::init<>())
.def("data", &RegisteredMemory::data)
.def("size", &RegisteredMemory::size)
.def("rank", &RegisteredMemory::rank)
.def("transports", &RegisteredMemory::transports)
.def("serialize", &RegisteredMemory::serialize)
.def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));
@@ -120,16 +119,42 @@ void register_core(nb::module_& m) {
},
nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue"))
.def("flush", &Connection::flush, nb::arg("timeoutUsec") = (int64_t)3e7)
.def("remote_rank", &Connection::remoteRank)
.def("tag", &Connection::tag)
.def("transport", &Connection::transport)
.def("remote_transport", &Connection::remoteTransport);
nb::class_<Endpoint>(m, "Endpoint")
.def("transport", &Endpoint::transport)
.def("serialize", &Endpoint::serialize)
.def_static("deserialize", &Endpoint::deserialize, nb::arg("data"));
nb::class_<EndpointConfig>(m, "EndpointConfig")
.def(nb::init<>())
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
.def_rw("transport", &EndpointConfig::transport)
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
.def_rw("ib_max_send_wr", &EndpointConfig::ibMaxSendWr)
.def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend);
nb::class_<Context>(m, "Context")
.def(nb::init<>())
.def(
"register_memory",
[](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {
return self->registerMemory((void*)ptr, size, transports);
},
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
def_nonblocking_future<RegisteredMemory>(m, "RegisteredMemory");
def_nonblocking_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
nb::class_<Communicator>(m, "Communicator")
.def(nb::init<std::shared_ptr<Bootstrap>>(), nb::arg("bootstrap"))
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
nb::arg("context") = nullptr)
.def("bootstrap", &Communicator::bootstrap)
.def("context", &Communicator::context)
.def(
"register_memory",
[](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {
@@ -140,8 +165,9 @@ void register_core(nb::module_& m) {
nb::arg("tag"))
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
nb::arg("transport"), nb::arg("ibMaxCqSize") = 1024, nb::arg("ibMaxCqPollNum") = 1,
nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64)
nb::arg("localConfig"))
.def("remote_rank_of", &Communicator::remoteRankOf)
.def("tag_of", &Communicator::tagOf)
.def("setup", &Communicator::setup);
}