mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 01:36:10 +00:00
Add Context and Endpoint classes to enable non-Communicator use-cases (#166)
This PR implements and closes #137. The new `Endpoint` and `Context` classes expose the connection establishing functionality from `Communicator`, which now is only responsible for tying together the bootstrapper with a context. The largest breaking change here is that `Communicator.connectOnSetup(...)` now returns the `Connection` wrapped inside a `NonblockingFuture`. This is because with the way `Context` is implemented a `Connection` is now fully initialized on construction. Some smaller breaking API changes from this change are that `RegisteredMemory` no longer has a `rank()` function (as there maybe no concept of rank), and similarly `Connection` has no `remoteRank()` and `tag()` functions. The latter are replaced by `remoteRankOf` and `tagOf` functions in `Communicator`. A new `EndpointConfig` class is introduced to avoid duplication of the IB configuration parameters in the APIs of `Context` and `Communicator`. The usual usage pattern of just passing in a `Transport` still works due to an implicit conversion into `EndpointConfig`. Miscellaneous changes: -Cleans up how the PIMPL pattern is applied by making both the `Impl` struct and the `pimpl_` pointers private for all relevant classes in the core API. -Enables ctest to be run from the build root directory.
This commit is contained in:
@@ -105,7 +105,6 @@ void register_core(nb::module_& m) {
|
||||
.def(nb::init<>())
|
||||
.def("data", &RegisteredMemory::data)
|
||||
.def("size", &RegisteredMemory::size)
|
||||
.def("rank", &RegisteredMemory::rank)
|
||||
.def("transports", &RegisteredMemory::transports)
|
||||
.def("serialize", &RegisteredMemory::serialize)
|
||||
.def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));
|
||||
@@ -120,16 +119,42 @@ void register_core(nb::module_& m) {
|
||||
},
|
||||
nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue"))
|
||||
.def("flush", &Connection::flush, nb::arg("timeoutUsec") = (int64_t)3e7)
|
||||
.def("remote_rank", &Connection::remoteRank)
|
||||
.def("tag", &Connection::tag)
|
||||
.def("transport", &Connection::transport)
|
||||
.def("remote_transport", &Connection::remoteTransport);
|
||||
|
||||
nb::class_<Endpoint>(m, "Endpoint")
|
||||
.def("transport", &Endpoint::transport)
|
||||
.def("serialize", &Endpoint::serialize)
|
||||
.def_static("deserialize", &Endpoint::deserialize, nb::arg("data"));
|
||||
|
||||
nb::class_<EndpointConfig>(m, "EndpointConfig")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
|
||||
.def_rw("transport", &EndpointConfig::transport)
|
||||
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
|
||||
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
|
||||
.def_rw("ib_max_send_wr", &EndpointConfig::ibMaxSendWr)
|
||||
.def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend);
|
||||
|
||||
nb::class_<Context>(m, "Context")
|
||||
.def(nb::init<>())
|
||||
.def(
|
||||
"register_memory",
|
||||
[](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {
|
||||
return self->registerMemory((void*)ptr, size, transports);
|
||||
},
|
||||
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
|
||||
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
|
||||
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
|
||||
|
||||
def_nonblocking_future<RegisteredMemory>(m, "RegisteredMemory");
|
||||
def_nonblocking_future<std::shared_ptr<Connection>>(m, "shared_ptr_Connection");
|
||||
|
||||
nb::class_<Communicator>(m, "Communicator")
|
||||
.def(nb::init<std::shared_ptr<Bootstrap>>(), nb::arg("bootstrap"))
|
||||
.def(nb::init<std::shared_ptr<Bootstrap>, std::shared_ptr<Context>>(), nb::arg("bootstrap"),
|
||||
nb::arg("context") = nullptr)
|
||||
.def("bootstrap", &Communicator::bootstrap)
|
||||
.def("context", &Communicator::context)
|
||||
.def(
|
||||
"register_memory",
|
||||
[](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {
|
||||
@@ -140,8 +165,9 @@ void register_core(nb::module_& m) {
|
||||
nb::arg("tag"))
|
||||
.def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
|
||||
.def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
|
||||
nb::arg("transport"), nb::arg("ibMaxCqSize") = 1024, nb::arg("ibMaxCqPollNum") = 1,
|
||||
nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64)
|
||||
nb::arg("localConfig"))
|
||||
.def("remote_rank_of", &Communicator::remoteRankOf)
|
||||
.def("tag_of", &Communicator::tagOf)
|
||||
.def("setup", &Communicator::setup);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user