Move barrier into setupNvlsChannels and clean up NVLS pipeline state

- setupNvlsChannels now takes the Communicator and barriers internally
  after binding all switch channels, replacing the explicit
  bootstrap()->barrier() previously done only in AllreduceNvlsPacket.
- Demote nRanksPerIpcDomain_ / nBaseChannels_ to locals in
  AllreduceNvlsBlockPipeline and AllreduceNvlsWarpPipeline; they were
  never read outside initialize().
- Drive-by: pick up in-tree edits to switch_channel_device.hpp,
  executor.cc, communicator.hpp, and allreduce_rsag.cu.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Binyang Li
2026-05-18 20:50:01 +00:00
parent 18d37379d2
commit 4db71b93b7
12 changed files with 24 additions and 29 deletions

View File

@@ -39,8 +39,7 @@ struct SwitchChannelDeviceHandle {
/// Vectorized multimem load+reduce. The optional `AccumT` template parameter selects the
/// accumulator: when `AccumT == __half` and `VectorType` is an FP8 vector type, the
/// `.acc::f16` variant of the instruction is used (faster but lower precision than the
/// default FP32 accumulator). For all other types `AccumT` is ignored.
/// `.acc::f16` variant of the instruction is used. For all other types `AccumT` is ignored.
template <typename VectorType, typename AccumT = void>
MSCCLPP_DEVICE_INLINE static VectorType multimemLoadReduce(VectorType* ptr) {
VectorType val;

View File

@@ -389,6 +389,7 @@ struct Executor::Impl {
nvlsConnection->bindAllocatedMemory((CUdeviceptr)bufferInfo.first, bufferInfo.second);
context.nvlsChannels.push_back(switchChannel);
}
this->comm->bootstrap()->barrier();
}
void setupSemaphores(ExecutionContext& context, const ExecutionPlan& plan) {

View File

@@ -60,6 +60,7 @@ struct Communicator::Impl {
std::shared_ptr<Bootstrap> bootstrap_;
std::shared_ptr<Context> context_;
std::unordered_map<const BaseConnection*, ConnectionInfo> connectionInfos_;
// Temporary storage for the latest RecvItem of each {remoteRank, tag} pair.
// The RecvItem is removed when it finishes or when getLastRecvItem observes that it is ready.
std::unordered_map<std::pair<int, int>, std::shared_ptr<BaseRecvItem>, PairHash> lastRecvItems_;

View File

@@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include <algorithm>
#include <mscclpp/algorithm.hpp>
#include "allreduce/allreduce_nvls_block_pipeline.hpp"
@@ -177,15 +176,15 @@ struct NvlsBlockPipelineAdapter {
void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr<Communicator> comm) {
nSwitchChannels_ = 8;
nRanksPerIpcDomain_ = comm->bootstrap()->getNranksPerIpcDomain();
int nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
// Per-peer channel allocation must hold up to 4 * nRanksPerIpcDomain entries (see kernel).
nBaseChannels_ = std::max(64, 4 * nRanksPerIpcDomain_);
int nBaseChannels = std::max(64, 4 * nRanksPerIpcDomain);
this->conns_ = setupConnections(comm);
// setup semaphores
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores =
setupMemorySemaphores(comm, this->conns_, nBaseChannels_);
setupMemorySemaphores(comm, this->conns_, nBaseChannels);
// setup base memory channels
this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nBaseChannels_);
this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nBaseChannels);
this->memoryChannelsDeviceHandle_ = setupBaseMemoryChannelDeviceHandles(this->baseChannels_);
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
}
@@ -228,7 +227,7 @@ std::shared_ptr<void> AllreduceNvlsBlockPipeline::initAllreduceContext(std::shar
// setup channels
ctx->switchChannels =
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_);
setupNvlsChannels(comm, this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_);
ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
return ctx;
}

View File

@@ -82,8 +82,7 @@ void AllreduceNvlsPacket::initialize(std::shared_ptr<Communicator> comm) {
int nSwitchChannels = 1;
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels);
this->switchChannels_ =
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
comm->bootstrap()->barrier();
setupNvlsChannels(comm, this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
}
AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {

View File

@@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include <algorithm>
#include <mscclpp/algorithm.hpp>
#include "allreduce/allreduce_nvls_warp_pipeline.hpp"
@@ -141,15 +140,15 @@ struct NvlsWarpPipelineAdapter {
void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr<Communicator> comm) {
nSwitchChannels_ = NUM_NVLS_CONNECTION;
nRanksPerIpcDomain_ = comm->bootstrap()->getNranksPerIpcDomain();
int nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
// Per-peer channel allocation must hold 2 * nBlocks entries; default nBlocks = 4 * nRanksPerIpcDomain.
nBaseChannels_ = std::max(64, 8 * nRanksPerIpcDomain_);
int nBaseChannels = std::max(64, 8 * nRanksPerIpcDomain);
this->conns_ = setupConnections(comm);
// setup semaphores
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores =
setupMemorySemaphores(comm, this->conns_, nBaseChannels_);
setupMemorySemaphores(comm, this->conns_, nBaseChannels);
// setup base memory channels
this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nBaseChannels_);
this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nBaseChannels);
this->memoryChannelsDeviceHandle_ = setupBaseMemoryChannelDeviceHandles(this->baseChannels_);
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
}
@@ -192,7 +191,7 @@ std::shared_ptr<void> AllreduceNvlsWarpPipeline::initAllreduceContext(std::share
// setup channels
ctx->switchChannels =
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_);
setupNvlsChannels(comm, this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_);
ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
return ctx;
}

View File

@@ -2,7 +2,6 @@
// Licensed under the MIT License.
#include <mscclpp/core.hpp>
#include <mscclpp/errors.hpp>
#include "allreduce/allreduce_nvls_zero_copy.hpp"
#include "allreduce/common.hpp"
@@ -195,11 +194,12 @@ std::shared_ptr<void> AllreduceNvls::initAllreduceContext(std::shared_ptr<mscclp
MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)output));
// setup channels
ctx->switchChannels = setupNvlsChannels(this->nvlsConnections_, (void*)sendBasePtr, sendBytes, nSwitchChannels_);
ctx->switchChannels =
setupNvlsChannels(comm, this->nvlsConnections_, (void*)sendBasePtr, sendBytes, nSwitchChannels_);
if (input != output) {
auto nvlsOutConnections = this->nvlsOutConnections_;
std::vector<mscclpp::SwitchChannel> outChannels =
setupNvlsChannels(this->nvlsOutConnections_, (void*)recvBasePtr, recvBytes, nSwitchChannels_);
setupNvlsChannels(comm, this->nvlsOutConnections_, (void*)recvBasePtr, recvBytes, nSwitchChannels_);
ctx->switchChannels.insert(ctx->switchChannels.end(), outChannels.begin(), outChannels.end());
}

View File

@@ -144,7 +144,7 @@ struct AllreduceRsAgAdapter {
void AllreduceRsAg::initialize(std::shared_ptr<Communicator> comm) {
this->conns_ = setupConnections(comm);
nChannelsPerConnection_ = 128;
nChannelsPerConnection_ = 64;
comm_ = comm;
// setup semaphores
this->scratchSemaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_);

View File

@@ -6,12 +6,9 @@
#include <algorithm>
#include <mscclpp/algorithm.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/errors.hpp>
#include <mscclpp/memory_channel.hpp>
#include <mscclpp/switch_channel.hpp>
#include "logger.hpp"
namespace mscclpp {
namespace collective {
std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_ptr<mscclpp::Communicator> comm, int rank,
@@ -101,7 +98,8 @@ std::vector<std::shared_ptr<mscclpp::NvlsConnection>> setupNvlsConnections(std::
return nvlsConnections;
}
std::vector<mscclpp::SwitchChannel> setupNvlsChannels(std::vector<std::shared_ptr<mscclpp::NvlsConnection>> conns,
std::vector<mscclpp::SwitchChannel> setupNvlsChannels(std::shared_ptr<mscclpp::Communicator> comm,
std::vector<std::shared_ptr<mscclpp::NvlsConnection>> conns,
void* buffer, size_t bufferSize, int nSwitchChannels) {
std::vector<mscclpp::SwitchChannel> channels;
@@ -110,6 +108,8 @@ std::vector<mscclpp::SwitchChannel> setupNvlsChannels(std::vector<std::shared_pt
mscclpp::SwitchChannel switchChannel = nvlsConnection->bindAllocatedMemory((CUdeviceptr)buffer, bufferSize);
channels.push_back(switchChannel);
}
// Synchronize to make sure all ranks have their NVLS channels set up before any rank starts using them.
comm->bootstrap()->barrier();
return channels;
}

View File

@@ -29,8 +29,6 @@ class AllreduceNvlsBlockPipeline : public AlgorithmBuilder {
void* scratchBuffer_;
size_t scratchBufferSize_;
uint32_t nSwitchChannels_;
int nRanksPerIpcDomain_ = 0;
int nBaseChannels_ = 0;
std::shared_ptr<DeviceHandle<BaseMemoryChannel>> memoryChannelsDeviceHandle_;
std::vector<BaseMemoryChannel> baseChannels_;
std::vector<Connection> conns_;

View File

@@ -29,8 +29,6 @@ class AllreduceNvlsWarpPipeline : public AlgorithmBuilder {
void* scratchBuffer_;
size_t scratchBufferSize_;
uint32_t nSwitchChannels_;
int nRanksPerIpcDomain_ = 0;
int nBaseChannels_ = 0;
std::shared_ptr<DeviceHandle<BaseMemoryChannel>> memoryChannelsDeviceHandle_;
std::vector<BaseMemoryChannel> baseChannels_;
std::vector<Connection> conns_;

View File

@@ -57,7 +57,8 @@ std::shared_ptr<DeviceHandle<MemoryChannel>> setupMemoryChannelDeviceHandles(
std::vector<std::shared_ptr<NvlsConnection>> setupNvlsConnections(std::shared_ptr<Communicator> comm, size_t size,
int numConnections);
std::vector<SwitchChannel> setupNvlsChannels(std::vector<std::shared_ptr<NvlsConnection>> conns, void* buffer,
std::vector<SwitchChannel> setupNvlsChannels(std::shared_ptr<Communicator> comm,
std::vector<std::shared_ptr<NvlsConnection>> conns, void* buffer,
size_t bufferSize, int nSwitchChannels);
std::shared_ptr<DeviceHandle<SwitchChannel>> setupNvlsChannelDeviceHandles(