Update EndpointConfig interfaces (#651)

* Separate IB-specific options into a nested struct
* Enable `connect()` by an `Endpoint`, not only by `EndpointConfig`
* Other minor changes
This commit is contained in:
Changho Hwang
2025-10-22 10:39:39 -07:00
committed by GitHub
parent 610db6f023
commit 200cdf946e
8 changed files with 153 additions and 110 deletions

View File

@@ -374,41 +374,53 @@ struct Device {
int id;
};
/// Used to configure an endpoint.
/// Configuration for creating communication endpoints.
struct EndpointConfig {
static const int DefaultMaxCqSize = 1024;
static const int DefaultMaxCqPollNum = 1;
static const int DefaultMaxSendWr = 8192;
static const int DefaultMaxWrPerSend = 64;
/// InfiniBand-specific configuration options that control queue pair behavior and performance characteristics.
/// These settings are only used when the transport is an InfiniBand type (IB0-IB7); they are ignored for other
/// transports.
struct Ib {
static const int DefaultMaxCqSize = 1024;
static const int DefaultMaxCqPollNum = 1;
static const int DefaultMaxSendWr = 8192;
static const int DefaultMaxWrPerSend = 64;
/// Maximum size of the completion queue.
int maxCqSize;
/// Maximum number of completion queue polls per operation.
int maxCqPollNum;
/// Maximum number of outstanding send work requests.
int maxSendWr;
/// Maximum number of work requests per send operation.
int maxWrPerSend;
/// Constructor.
/// @param maxCqSize Maximum completion queue size.
/// @param maxCqPollNum Maximum completion queue poll count.
/// @param maxSendWr Maximum outstanding send work requests.
/// @param maxWrPerSend Maximum work requests per send operation.
Ib(int maxCqSize = DefaultMaxCqSize, int maxCqPollNum = DefaultMaxCqPollNum, int maxSendWr = DefaultMaxSendWr,
int maxWrPerSend = DefaultMaxWrPerSend)
: maxCqSize(maxCqSize), maxCqPollNum(maxCqPollNum), maxSendWr(maxSendWr), maxWrPerSend(maxWrPerSend) {}
};
/// Communication transport type (e.g., CudaIpc, IB0-IB7, Ethernet).
Transport transport;
/// Target device for the endpoint (GPU or CPU with optional device ID).
Device device;
int ibMaxCqSize;
int ibMaxCqPollNum;
int ibMaxSendWr;
int ibMaxWrPerSend;
/// Maximum number of write requests that can be queued (-1 for default).
int maxWriteQueueSize;
/// InfiniBand-specific options (used only for Transport::IBx).
Ib ib;
/// Constructor that takes a transport and sets the other fields to their default values.
///
/// @param transport The transport to use.
/// @param device The device to use.
/// @param ibMaxCqSize The maximum completion queue size.
/// @param ibMaxCqPollNum The maximum completion queue poll number.
/// @param ibMaxSendWr The maximum send work requests.
/// @param ibMaxWrPerSend The maximum work requests per send.
/// @param maxWriteQueueSize The maximum write queue size.
EndpointConfig(Transport transport = Transport::Unknown, Device device = DeviceType::GPU,
int ibMaxCqSize = DefaultMaxCqSize, int ibMaxCqPollNum = DefaultMaxCqPollNum,
int ibMaxSendWr = DefaultMaxSendWr, int ibMaxWrPerSend = DefaultMaxWrPerSend,
int maxWriteQueueSize = -1)
: transport(transport),
device(device),
ibMaxCqSize(ibMaxCqSize),
ibMaxCqPollNum(ibMaxCqPollNum),
ibMaxSendWr(ibMaxSendWr),
ibMaxWrPerSend(ibMaxWrPerSend),
maxWriteQueueSize(maxWriteQueueSize) {}
/// Constructs endpoint configuration with specified transport, device, and optional settings.
/// @param transport Communication transport to use.
/// @param device Target device for the endpoint.
/// @param maxWriteQueueSize Maximum write queue size (-1 for system default).
/// @param ib IB-specific configuration.
EndpointConfig(Transport transport = Transport::Unknown, Device device = DeviceType::GPU, int maxWriteQueueSize = -1,
Ib ib = {})
: transport(transport), device(device), maxWriteQueueSize(maxWriteQueueSize), ib(ib) {}
};
class Context;
@@ -423,6 +435,10 @@ class Endpoint {
/// Constructor.
Endpoint() = default;
/// Get the configuration used to create the endpoint.
/// @return The configuration used to create the endpoint.
const EndpointConfig& config() const;
/// Get the transport used.
/// @return The transport used.
Transport transport() const;
@@ -685,9 +701,9 @@ class Semaphore {
std::shared_ptr<Impl> pimpl_;
};
/// Deprecated.
template <typename T>
using NonblockingFuture [[deprecated("Use std::shared_future instead. This will be removed in a future release.")]] =
std::shared_future<T>;
using NonblockingFuture = std::shared_future<T>;
/// A class that sets up all registered memories and connections between processes.
///
@@ -853,12 +869,20 @@ class Communicator {
/// on the last future, it will start receiving the five RegisteredMemory or Connection objects in order,
/// back to back.
///
/// @param localConfig The configuration for the local endpoint.
/// @param localEndpoint The local endpoint.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the send and receive.
/// @return A future of shared pointer to the connection.
///
std::shared_future<std::shared_ptr<Connection>> connect(EndpointConfig localConfig, int remoteRank, int tag = 0);
std::shared_future<std::shared_ptr<Connection>> connect(const Endpoint& localEndpoint, int remoteRank, int tag = 0);
/// Connect to a remote rank. Wrapper of `connect(localEndpoint, remoteRank, tag)`.
/// @param localConfig The configuration for the local endpoint.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the send and receive.
/// @return A future of shared pointer to the connection.
std::shared_future<std::shared_ptr<Connection>> connect(const EndpointConfig& localConfig, int remoteRank,
int tag = 0);
[[deprecated("Use connect(localConfig, remoteRank, tag) instead. This will be removed in a future release.")]] std::
shared_future<std::shared_ptr<Connection>>

View File

@@ -124,6 +124,17 @@ void register_core(nb::module_& m) {
.def_rw("id", &Device::id)
.def("__str__", [](const Device& self) { return std::to_string(self); });
nb::class_<EndpointConfig::Ib>(m, "EndpointConfigIb")
.def(nb::init<>())
.def(nb::init<int, int, int, int>(), nb::arg("maxCqSize") = EndpointConfig::Ib::DefaultMaxCqSize,
nb::arg("maxCqPollNum") = EndpointConfig::Ib::DefaultMaxCqPollNum,
nb::arg("maxSendWr") = EndpointConfig::Ib::DefaultMaxSendWr,
nb::arg("maxWrPerSend") = EndpointConfig::Ib::DefaultMaxWrPerSend)
.def_rw("max_cq_size", &EndpointConfig::Ib::maxCqSize)
.def_rw("max_cq_poll_num", &EndpointConfig::Ib::maxCqPollNum)
.def_rw("max_send_wr", &EndpointConfig::Ib::maxSendWr)
.def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend);
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
.def(nb::init<>())
.def("data", [](RegisteredMemory& self) { return reinterpret_cast<uintptr_t>(self.data()); })
@@ -158,17 +169,23 @@ void register_core(nb::module_& m) {
nb::class_<EndpointConfig>(m, "EndpointConfig")
.def(nb::init<>())
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
.def(nb::init<Transport, Device, int, int, int, int, int>(), nb::arg("transport"), nb::arg("device"),
nb::arg("ibMaxCqSize") = EndpointConfig::DefaultMaxCqSize,
nb::arg("ibMaxCqPollNum") = EndpointConfig::DefaultMaxCqPollNum,
nb::arg("ibMaxSendWr") = EndpointConfig::DefaultMaxSendWr,
nb::arg("ibMaxWrPerSend") = EndpointConfig::DefaultMaxWrPerSend, nb::arg("maxWriteQueueSize") = -1)
.def(nb::init<Transport, Device, int, EndpointConfig::Ib>(), nb::arg("transport"), nb::arg("device"),
nb::arg("maxWriteQueueSize") = -1, nb::arg("ib") = EndpointConfig::Ib{})
.def_rw("transport", &EndpointConfig::transport)
.def_rw("device", &EndpointConfig::device)
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
.def_rw("ib_max_send_wr", &EndpointConfig::ibMaxSendWr)
.def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend)
.def_rw("ib", &EndpointConfig::ib)
.def_prop_rw(
"ib_max_cq_size", [](EndpointConfig& self) { return self.ib.maxCqSize; },
[](EndpointConfig& self, int v) { self.ib.maxCqSize = v; })
.def_prop_rw(
"ib_max_cq_poll_num", [](EndpointConfig& self) { return self.ib.maxCqPollNum; },
[](EndpointConfig& self, int v) { self.ib.maxCqPollNum = v; })
.def_prop_rw(
"ib_max_send_wr", [](EndpointConfig& self) { return self.ib.maxSendWr; },
[](EndpointConfig& self, int v) { self.ib.maxSendWr = v; })
.def_prop_rw(
"ib_max_wr_per_send", [](EndpointConfig& self) { return self.ib.maxWrPerSend; },
[](EndpointConfig& self, int v) { self.ib.maxWrPerSend = v; })
.def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize);
nb::class_<Context>(m, "Context")
@@ -212,13 +229,15 @@ void register_core(nb::module_& m) {
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag") = 0)
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag") = 0)
.def("connect",
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(EndpointConfig, int, int)>(
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(const Endpoint&, int, int)>(
&Communicator::connect),
nb::arg("localConfig"), nb::arg("remoteRank"), nb::arg("tag") = 0)
nb::arg("localEndpoint"), nb::arg("remoteRank"), nb::arg("tag") = 0)
.def("connect", [](Communicator* self, const EndpointConfig& localConfig, int remoteRank,
int tag = 0) { return self->connect(localConfig, remoteRank, tag); })
.def(
"connect",
[](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
return self->connect(std::move(localConfig), remoteRank, tag);
[](Communicator* self, int remoteRank, int tag, const EndpointConfig& localConfig) {
return self->connect(localConfig, remoteRank, tag);
},
nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
.def(

View File

@@ -99,41 +99,44 @@ MSCCLPP_API_CPP std::shared_future<RegisteredMemory> Communicator::recvMemory(in
return shared_future;
}
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(EndpointConfig localConfig,
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(const Endpoint& localEndpoint,
int remoteRank, int tag) {
auto localEndpoint = context()->createEndpoint(localConfig);
if (remoteRank == bootstrap()->getRank()) {
// Connection to self
auto remoteEndpoint = context()->createEndpoint(localConfig);
auto remoteEndpoint = context()->createEndpoint(localEndpoint.config());
auto connection = context()->connect(localEndpoint, remoteEndpoint);
std::promise<std::shared_ptr<Connection>> promise;
promise.set_value(connection);
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
return std::shared_future<std::shared_ptr<Connection>>(std::move(promise.get_future()));
return std::shared_future<std::shared_ptr<Connection>>(promise.get_future());
}
bootstrap()->send(localEndpoint.serialize(), remoteRank, tag);
auto future =
std::async(std::launch::deferred, [this, remoteRank, tag, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag),
localEndpoint = std::move(localEndpoint)]() mutable {
if (lastRecvItem) {
// Recursive call to the previous receive items
lastRecvItem->wait();
}
std::vector<char> data;
bootstrap()->recv(data, remoteRank, tag);
auto remoteEndpoint = Endpoint::deserialize(data);
auto connection = context()->connect(localEndpoint, remoteEndpoint);
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
return connection;
});
auto future = std::async(std::launch::deferred, [this, remoteRank, tag, localEndpoint,
lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag)]() mutable {
if (lastRecvItem) {
// Recursive call to the previous receive items
lastRecvItem->wait();
}
std::vector<char> data;
bootstrap()->recv(data, remoteRank, tag);
auto remoteEndpoint = Endpoint::deserialize(data);
auto connection = context()->connect(localEndpoint, remoteEndpoint);
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
return connection;
});
auto shared_future = std::shared_future<std::shared_ptr<Connection>>(std::move(future));
pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<std::shared_ptr<Connection>>>(shared_future));
return shared_future;
}
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(const EndpointConfig& localConfig,
int remoteRank, int tag) {
auto localEndpoint = context()->createEndpoint(localConfig);
return connect(localEndpoint, remoteRank, tag);
}
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(int remoteRank, int tag,
EndpointConfig localConfig) {
return connect(localConfig, remoteRank, tag);

View File

@@ -167,9 +167,6 @@ IBConnection::IBConnection(std::shared_ptr<Context> context, const Endpoint& loc
transport_(localEndpoint.transport()),
remoteTransport_(remoteEndpoint.transport()),
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
if (maxWriteQueueSize_ == -1) {
maxWriteQueueSize_ = EndpointConfig::DefaultMaxCqSize;
}
qp_ = getImpl(localEndpoint).ibQp_;
qp_.lock()->rtr(getImpl(remoteEndpoint).ibQpInfo_);
qp_.lock()->rts();

View File

@@ -4,6 +4,7 @@
#include "context.hpp"
#include <mscclpp/env.hpp>
#include <sstream>
#include "api.h"
#include "connection.hpp"
@@ -76,21 +77,21 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(const Endpoint &loc
if (remoteEndpoint.device().type == DeviceType::GPU && remoteEndpoint.device().id < 0) {
throw Error("No GPU device ID provided for remote endpoint", ErrorCode::InvalidUsage);
}
auto localTransport = localEndpoint.transport();
auto remoteTransport = remoteEndpoint.transport();
if (localTransport != remoteTransport &&
!(AllIBTransports.has(localTransport) && AllIBTransports.has(remoteTransport))) {
std::stringstream ss;
ss << "Transport mismatch between local (" << std::to_string(localTransport) << ") and remote ("
<< std::to_string(remoteEndpoint.transport()) << ") endpoints";
throw Error(ss.str(), ErrorCode::InvalidUsage);
}
std::shared_ptr<Connection> conn;
if (localEndpoint.transport() == Transport::CudaIpc) {
if (remoteEndpoint.transport() != Transport::CudaIpc) {
throw Error("Local transport is CudaIpc but remote is not", ErrorCode::InvalidUsage);
}
if (localTransport == Transport::CudaIpc) {
conn = std::make_shared<CudaIpcConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
} else if (AllIBTransports.has(localEndpoint.transport())) {
if (!AllIBTransports.has(remoteEndpoint.transport())) {
throw Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage);
}
} else if (AllIBTransports.has(localTransport)) {
conn = std::make_shared<IBConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
} else if (localEndpoint.transport() == Transport::Ethernet) {
if (remoteEndpoint.transport() != Transport::Ethernet) {
throw Error("Local transport is Ethernet but remote is not", ErrorCode::InvalidUsage);
}
} else if (localTransport == Transport::Ethernet) {
conn = std::make_shared<EthernetConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
} else {
throw Error("Unsupported transport", ErrorCode::InternalError);

View File

@@ -13,21 +13,21 @@
namespace mscclpp {
Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
: transport_(config.transport),
device_(config.device),
hostHash_(getHostHash()),
pidHash_(getPidHash()),
maxWriteQueueSize_(config.maxWriteQueueSize) {
if (device_.type == DeviceType::GPU && device_.id < 0) {
MSCCLPP_CUDATHROW(cudaGetDevice(&(device_.id)));
Endpoint::Impl::Impl(const EndpointConfig& config, Context::Impl& contextImpl)
: config_(config), hostHash_(getHostHash()), pidHash_(getPidHash()) {
if (config_.device.type == DeviceType::GPU && config_.device.id < 0) {
MSCCLPP_CUDATHROW(cudaGetDevice(&(config_.device.id)));
}
if (AllIBTransports.has(transport_)) {
if (AllIBTransports.has(config_.transport)) {
ibLocal_ = true;
ibQp_ = contextImpl.getIbContext(transport_)
->createQp(config.ibMaxCqSize, config.ibMaxCqPollNum, config.ibMaxSendWr, 0, config.ibMaxWrPerSend);
if (config_.maxWriteQueueSize <= 0) {
config_.maxWriteQueueSize = config_.ib.maxCqSize;
}
ibQp_ =
contextImpl.getIbContext(config_.transport)
->createQp(config_.ib.maxCqSize, config_.ib.maxCqPollNum, config_.ib.maxSendWr, 0, config_.ib.maxWrPerSend);
ibQpInfo_ = ibQp_->getInfo();
} else if (transport_ == Transport::Ethernet) {
} else if (config_.transport == Transport::Ethernet) {
// Configuring Ethernet Interfaces
abortFlag_ = 0;
int ret = FindInterfaces(netIfName_, &socketAddress_, MAX_IF_NAME_SIZE, 1);
@@ -42,41 +42,42 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
Endpoint::Impl::Impl(const std::vector<char>& serialization) {
auto it = serialization.begin();
it = detail::deserialize(it, transport_);
it = detail::deserialize(it, device_);
it = detail::deserialize(it, config_);
it = detail::deserialize(it, hostHash_);
it = detail::deserialize(it, pidHash_);
if (AllIBTransports.has(transport_)) {
if (AllIBTransports.has(config_.transport)) {
ibLocal_ = false;
it = detail::deserialize(it, ibQpInfo_);
}
if (transport_ == Transport::Ethernet) {
} else if (config_.transport == Transport::Ethernet) {
it = detail::deserialize(it, socketAddress_);
}
if (it != serialization.end()) {
throw Error("Endpoint deserialization failed", ErrorCode::Aborted);
}
}
MSCCLPP_API_CPP Endpoint::Endpoint(std::shared_ptr<Endpoint::Impl> pimpl) : pimpl_(pimpl) {}
MSCCLPP_API_CPP Transport Endpoint::transport() const { return pimpl_->transport_; }
MSCCLPP_API_CPP const EndpointConfig& Endpoint::config() const { return pimpl_->config_; }
MSCCLPP_API_CPP const Device& Endpoint::device() const { return pimpl_->device_; }
MSCCLPP_API_CPP Transport Endpoint::transport() const { return pimpl_->config_.transport; }
MSCCLPP_API_CPP const Device& Endpoint::device() const { return pimpl_->config_.device; }
MSCCLPP_API_CPP uint64_t Endpoint::hostHash() const { return pimpl_->hostHash_; }
MSCCLPP_API_CPP uint64_t Endpoint::pidHash() const { return pimpl_->pidHash_; }
MSCCLPP_API_CPP int Endpoint::maxWriteQueueSize() const { return pimpl_->maxWriteQueueSize_; }
MSCCLPP_API_CPP int Endpoint::maxWriteQueueSize() const { return pimpl_->config_.maxWriteQueueSize; }
MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() const {
std::vector<char> data;
detail::serialize(data, pimpl_->transport_);
detail::serialize(data, pimpl_->device_);
detail::serialize(data, pimpl_->config_);
detail::serialize(data, pimpl_->hostHash_);
detail::serialize(data, pimpl_->pidHash_);
if (AllIBTransports.has(pimpl_->transport_)) {
if (AllIBTransports.has(pimpl_->config_.transport)) {
detail::serialize(data, pimpl_->ibQpInfo_);
}
if ((pimpl_->transport_) == Transport::Ethernet) {
} else if (pimpl_->config_.transport == Transport::Ethernet) {
detail::serialize(data, pimpl_->socketAddress_);
}
return data;

View File

@@ -36,7 +36,7 @@ class RecvItem : public BaseRecvItem {
class LocalRecvMemory {
public:
LocalRecvMemory() : future_(std::move(promise_.get_future())) {}
LocalRecvMemory() : future_(promise_.get_future()) {}
void set(RegisteredMemory memory) { promise_.set_value(std::move(memory)); }

View File

@@ -15,14 +15,12 @@
namespace mscclpp {
struct Endpoint::Impl {
Impl(EndpointConfig config, Context::Impl& contextImpl);
Impl(const EndpointConfig& config, Context::Impl& contextImpl);
Impl(const std::vector<char>& serialization);
Transport transport_;
Device device_;
EndpointConfig config_;
uint64_t hostHash_;
uint64_t pidHash_;
int maxWriteQueueSize_;
// The following are only used for IB and are undefined for other transports.
bool ibLocal_;