Merge branch 'main' into rjsouza/nvls-allgather-pr

This commit is contained in:
Empyreus
2026-06-23 23:27:24 +00:00
7 changed files with 366 additions and 115 deletions

View File

@@ -95,6 +95,7 @@ struct hash<mscclpp::DeviceExecutionPlanKey> {
namespace {
// Returns true when mscclpp::getIBDeviceCount() reports a positive device count.
auto hasIBDevices = []() {
  const auto ibDeviceCount = mscclpp::getIBDeviceCount();
  return ibDeviceCount > 0;
};
// TODO(binyli): Need to add NVL domain check.
auto useIB = [](int rank1, int rank2, int nranksPerNode) {
if (mscclpp::env()->forceDisableIb) return false;
bool inSameNode = rank1 / nranksPerNode == rank2 / nranksPerNode;
@@ -110,7 +111,7 @@ namespace mscclpp {
struct ExecutionContext {
std::shared_ptr<ProxyService> proxyService;
std::unordered_map<int, Connection> connections;
std::vector<Connection> connections;
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections;
MemoryId localMemoryIdBegin = MemoryId(0);
@@ -122,8 +123,6 @@ struct ExecutionContext {
// local registered memories to keep resources alive
std::vector<mscclpp::RegisteredMemory> localRegisteredMemories;
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
std::vector<mscclpp::SemaphoreId> proxySemaphores;
std::vector<mscclpp::BaseMemoryChannel> memoryChannels;
std::vector<mscclpp::BasePortChannel> portChannels;
std::vector<mscclpp::SwitchChannel> nvlsChannels;
@@ -267,15 +266,36 @@ struct Executor::Impl {
}
};
std::vector<int> connectedPeers = plan.impl_->getConnectedPeers();
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
for (int peer : connectedPeers) {
Transport transport =
!useIB(rank, peer, this->nranksPerNode) ? Transport::CudaIpc : IBs[rank % this->nranksPerNode];
connectionFutures.push_back(this->comm->connect(transport, peer));
// Create one connection (unique QP) per channel entry. Each channel gets its own
// QP — no shared connections.
// Use per-peer tag counters so that matched connections between pairs of ranks use
// the same tag, regardless of the order peers appear in each rank's connected_to list.
std::unordered_map<int, int> peerTagCounters;
Transport ibTransport = IBs[rank % this->nranksPerNode];
std::vector<std::shared_future<Connection>> connFutures;
for (ChannelType channelType : {ChannelType::MEMORY, ChannelType::PORT}) {
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(channelType);
for (const auto& info : channelInfos) {
for (int peer : info.connectedPeers) {
Transport transport = channelType == ChannelType::PORT && useIB(rank, peer, this->nranksPerNode)
? ibTransport
: Transport::CudaIpc;
connFutures.push_back(this->comm->connect(transport, peer, peerTagCounters[peer]++));
}
}
channelInfos = plan.impl_->getUnpairedChannelInfos(nranks, channelType);
for (const auto& info : channelInfos) {
for (int peer : info.connectedPeers) {
Transport transport = channelType == ChannelType::PORT && useIB(rank, peer, this->nranksPerNode)
? ibTransport
: Transport::CudaIpc;
connFutures.push_back(this->comm->connect(transport, peer, peerTagCounters[peer]++));
}
}
}
for (size_t i = 0; i < connectionFutures.size(); i++) {
context.connections[connectedPeers[i]] = connectionFutures[i].get();
for (auto& future : connFutures) {
context.connections.push_back(future.get());
}
std::vector<NvlsInfo> nvlsInfos = plan.impl_->nvlsInfos.at(rank);
@@ -329,10 +349,11 @@ struct Executor::Impl {
std::vector<std::shared_future<Semaphore>> futureProxySemaphores;
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores;
std::vector<mscclpp::SemaphoreId> proxySemaphores;
int connIdx = 0;
auto processChannelInfos = [&](std::vector<ChannelInfo>& channelInfos) {
for (ChannelInfo& info : channelInfos) {
for (int peer : info.connectedPeers) {
auto connection = context.connections.at(peer);
for (size_t i = 0; i < info.connectedPeers.size(); i++) {
auto& connection = context.connections[connIdx++];
if (info.channelType == ChannelType::MEMORY) {
futureMemorySemaphores.push_back(this->comm->buildSemaphore(
connection, this->comm->remoteRankOf(connection), this->comm->tagOf(connection)));
@@ -361,18 +382,15 @@ struct Executor::Impl {
proxySemaphores.push_back(context.proxyService->addSemaphore(sem.get()));
}
context.memorySemaphores = std::move(memorySemaphores);
context.proxySemaphores = std::move(proxySemaphores);
for (ChannelType channelType : channelTypes) {
std::vector<ChannelInfo> channelInfos = plan.impl_->getChannelInfos(channelType);
int index = 0;
for (ChannelInfo& info : channelInfos) {
for (size_t i = 0; i < info.connectedPeers.size(); i++) {
if (channelType == ChannelType::MEMORY) {
context.memoryChannels.emplace_back(context.memorySemaphores[index++]);
context.memoryChannels.emplace_back(memorySemaphores[index++]);
} else if (channelType == ChannelType::PORT) {
context.portChannels.emplace_back(context.proxyService->basePortChannel(context.proxySemaphores[index++]));
context.portChannels.emplace_back(context.proxyService->basePortChannel(proxySemaphores[index++]));
}
}
}

View File

@@ -174,11 +174,11 @@ MSCCLPP_DEVICE_INLINE void handlePut(const Operation& op, void* input, void* out
uint32_t dstOffset =
dstOffsets[tid] + getOffset<ReuseScratch>(portChannelBufferTypes_[op.outputBufferRefs[tid].id], offset);
uint32_t srcOffset = srcOffsets[tid] + getOffset<ReuseScratch>(op.inputBufferRefs[tid].type, offset);
if constexpr (PutWithSignal) {
portChannels_[channelIndexes[tid]].putWithSignal(dstMemoryId, dstOffset, srcMemoryId, srcOffset, size);
} else if constexpr (PutWithSignalAndFlush) {
if constexpr (PutWithSignalAndFlush) {
portChannels_[channelIndexes[tid]].putWithSignalAndFlush(dstMemoryId, (uint64_t)dstOffset, srcMemoryId,
(uint64_t)srcOffsets, size);
(uint64_t)srcOffset, size);
} else if constexpr (PutWithSignal) {
portChannels_[channelIndexes[tid]].putWithSignal(dstMemoryId, dstOffset, srcMemoryId, srcOffset, size);
} else {
portChannels_[channelIndexes[tid]].put(dstMemoryId, dstOffset, srcMemoryId, srcOffset, size);
}