Update EndpointConfig interfaces (#651)

* Separate IB-specific options into a nested struct
* Allow `connect()` to take an `Endpoint`, not only an `EndpointConfig`
* Other minor changes
This commit is contained in:
Changho Hwang
2025-10-22 10:39:39 -07:00
committed by GitHub
parent 610db6f023
commit 200cdf946e
8 changed files with 153 additions and 110 deletions

View File

@@ -124,6 +124,17 @@ void register_core(nb::module_& m) {
.def_rw("id", &Device::id)
.def("__str__", [](const Device& self) { return std::to_string(self); });
// Bind the nested EndpointConfig::Ib options struct to Python as `EndpointConfigIb`.
// The four-int constructor accepts keyword arguments, each defaulting to the matching
// EndpointConfig::Ib::Default* constant; the same four fields are also exposed as
// read/write attributes under snake_case names.
nb::class_<EndpointConfig::Ib>(m, "EndpointConfigIb")
.def(nb::init<>())
.def(nb::init<int, int, int, int>(), nb::arg("maxCqSize") = EndpointConfig::Ib::DefaultMaxCqSize,
nb::arg("maxCqPollNum") = EndpointConfig::Ib::DefaultMaxCqPollNum,
nb::arg("maxSendWr") = EndpointConfig::Ib::DefaultMaxSendWr,
nb::arg("maxWrPerSend") = EndpointConfig::Ib::DefaultMaxWrPerSend)
.def_rw("max_cq_size", &EndpointConfig::Ib::maxCqSize)
.def_rw("max_cq_poll_num", &EndpointConfig::Ib::maxCqPollNum)
.def_rw("max_send_wr", &EndpointConfig::Ib::maxSendWr)
.def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend);
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
.def(nb::init<>())
.def("data", [](RegisteredMemory& self) { return reinterpret_cast<uintptr_t>(self.data()); })
@@ -158,17 +169,23 @@ void register_core(nb::module_& m) {
nb::class_<EndpointConfig>(m, "EndpointConfig")
.def(nb::init<>())
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
.def(nb::init<Transport, Device, int, int, int, int, int>(), nb::arg("transport"), nb::arg("device"),
nb::arg("ibMaxCqSize") = EndpointConfig::DefaultMaxCqSize,
nb::arg("ibMaxCqPollNum") = EndpointConfig::DefaultMaxCqPollNum,
nb::arg("ibMaxSendWr") = EndpointConfig::DefaultMaxSendWr,
nb::arg("ibMaxWrPerSend") = EndpointConfig::DefaultMaxWrPerSend, nb::arg("maxWriteQueueSize") = -1)
.def(nb::init<Transport, Device, int, EndpointConfig::Ib>(), nb::arg("transport"), nb::arg("device"),
nb::arg("maxWriteQueueSize") = -1, nb::arg("ib") = EndpointConfig::Ib{})
.def_rw("transport", &EndpointConfig::transport)
.def_rw("device", &EndpointConfig::device)
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
.def_rw("ib_max_send_wr", &EndpointConfig::ibMaxSendWr)
.def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend)
.def_rw("ib", &EndpointConfig::ib)
.def_prop_rw(
"ib_max_cq_size", [](EndpointConfig& self) { return self.ib.maxCqSize; },
[](EndpointConfig& self, int v) { self.ib.maxCqSize = v; })
.def_prop_rw(
"ib_max_cq_poll_num", [](EndpointConfig& self) { return self.ib.maxCqPollNum; },
[](EndpointConfig& self, int v) { self.ib.maxCqPollNum = v; })
.def_prop_rw(
"ib_max_send_wr", [](EndpointConfig& self) { return self.ib.maxSendWr; },
[](EndpointConfig& self, int v) { self.ib.maxSendWr = v; })
.def_prop_rw(
"ib_max_wr_per_send", [](EndpointConfig& self) { return self.ib.maxWrPerSend; },
[](EndpointConfig& self, int v) { self.ib.maxWrPerSend = v; })
.def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize);
nb::class_<Context>(m, "Context")
@@ -212,13 +229,15 @@ void register_core(nb::module_& m) {
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag") = 0)
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag") = 0)
.def("connect",
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(EndpointConfig, int, int)>(
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(const Endpoint&, int, int)>(
&Communicator::connect),
nb::arg("localConfig"), nb::arg("remoteRank"), nb::arg("tag") = 0)
nb::arg("localEndpoint"), nb::arg("remoteRank"), nb::arg("tag") = 0)
.def("connect", [](Communicator* self, const EndpointConfig& localConfig, int remoteRank,
int tag = 0) { return self->connect(localConfig, remoteRank, tag); })
.def(
"connect",
[](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
return self->connect(std::move(localConfig), remoteRank, tag);
[](Communicator* self, int remoteRank, int tag, const EndpointConfig& localConfig) {
return self->connect(localConfig, remoteRank, tag);
},
nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
.def(