mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Update EndpointConfig interfaces (#651)
* Separate IB-specific options into a nested struct * Enable `connect()` by an `Endpoint`, not only by `EndpointConfig` * Other minor changes
This commit is contained in:
@@ -124,6 +124,17 @@ void register_core(nb::module_& m) {
|
||||
.def_rw("id", &Device::id)
|
||||
.def("__str__", [](const Device& self) { return std::to_string(self); });
|
||||
|
||||
nb::class_<EndpointConfig::Ib>(m, "EndpointConfigIb")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<int, int, int, int>(), nb::arg("maxCqSize") = EndpointConfig::Ib::DefaultMaxCqSize,
|
||||
nb::arg("maxCqPollNum") = EndpointConfig::Ib::DefaultMaxCqPollNum,
|
||||
nb::arg("maxSendWr") = EndpointConfig::Ib::DefaultMaxSendWr,
|
||||
nb::arg("maxWrPerSend") = EndpointConfig::Ib::DefaultMaxWrPerSend)
|
||||
.def_rw("max_cq_size", &EndpointConfig::Ib::maxCqSize)
|
||||
.def_rw("max_cq_poll_num", &EndpointConfig::Ib::maxCqPollNum)
|
||||
.def_rw("max_send_wr", &EndpointConfig::Ib::maxSendWr)
|
||||
.def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend);
|
||||
|
||||
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
|
||||
.def(nb::init<>())
|
||||
.def("data", [](RegisteredMemory& self) { return reinterpret_cast<uintptr_t>(self.data()); })
|
||||
@@ -158,17 +169,23 @@ void register_core(nb::module_& m) {
|
||||
nb::class_<EndpointConfig>(m, "EndpointConfig")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
|
||||
.def(nb::init<Transport, Device, int, int, int, int, int>(), nb::arg("transport"), nb::arg("device"),
|
||||
nb::arg("ibMaxCqSize") = EndpointConfig::DefaultMaxCqSize,
|
||||
nb::arg("ibMaxCqPollNum") = EndpointConfig::DefaultMaxCqPollNum,
|
||||
nb::arg("ibMaxSendWr") = EndpointConfig::DefaultMaxSendWr,
|
||||
nb::arg("ibMaxWrPerSend") = EndpointConfig::DefaultMaxWrPerSend, nb::arg("maxWriteQueueSize") = -1)
|
||||
.def(nb::init<Transport, Device, int, EndpointConfig::Ib>(), nb::arg("transport"), nb::arg("device"),
|
||||
nb::arg("maxWriteQueueSize") = -1, nb::arg("ib") = EndpointConfig::Ib{})
|
||||
.def_rw("transport", &EndpointConfig::transport)
|
||||
.def_rw("device", &EndpointConfig::device)
|
||||
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
|
||||
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
|
||||
.def_rw("ib_max_send_wr", &EndpointConfig::ibMaxSendWr)
|
||||
.def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend)
|
||||
.def_rw("ib", &EndpointConfig::ib)
|
||||
.def_prop_rw(
|
||||
"ib_max_cq_size", [](EndpointConfig& self) { return self.ib.maxCqSize; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxCqSize = v; })
|
||||
.def_prop_rw(
|
||||
"ib_max_cq_poll_num", [](EndpointConfig& self) { return self.ib.maxCqPollNum; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxCqPollNum = v; })
|
||||
.def_prop_rw(
|
||||
"ib_max_send_wr", [](EndpointConfig& self) { return self.ib.maxSendWr; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxSendWr = v; })
|
||||
.def_prop_rw(
|
||||
"ib_max_wr_per_send", [](EndpointConfig& self) { return self.ib.maxWrPerSend; },
|
||||
[](EndpointConfig& self, int v) { self.ib.maxWrPerSend = v; })
|
||||
.def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize);
|
||||
|
||||
nb::class_<Context>(m, "Context")
|
||||
@@ -212,13 +229,15 @@ void register_core(nb::module_& m) {
|
||||
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag") = 0)
|
||||
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag") = 0)
|
||||
.def("connect",
|
||||
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(EndpointConfig, int, int)>(
|
||||
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(const Endpoint&, int, int)>(
|
||||
&Communicator::connect),
|
||||
nb::arg("localConfig"), nb::arg("remoteRank"), nb::arg("tag") = 0)
|
||||
nb::arg("localEndpoint"), nb::arg("remoteRank"), nb::arg("tag") = 0)
|
||||
.def("connect", [](Communicator* self, const EndpointConfig& localConfig, int remoteRank,
|
||||
int tag = 0) { return self->connect(localConfig, remoteRank, tag); })
|
||||
.def(
|
||||
"connect",
|
||||
[](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
|
||||
return self->connect(std::move(localConfig), remoteRank, tag);
|
||||
[](Communicator* self, int remoteRank, int tag, const EndpointConfig& localConfig) {
|
||||
return self->connect(localConfig, remoteRank, tag);
|
||||
},
|
||||
nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
|
||||
.def(
|
||||
|
||||
Reference in New Issue
Block a user