mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-24 23:06:17 +00:00
Move barrier into setupNvlsChannels and clean up NVLS pipeline state
- setupNvlsChannels now takes the Communicator and barriers internally after binding all switch channels, replacing the explicit bootstrap()->barrier() previously done only in AllreduceNvlsPacket. - Demote nRanksPerIpcDomain_ / nBaseChannels_ to locals in AllreduceNvlsBlockPipeline and AllreduceNvlsWarpPipeline; they were never read outside initialize(). - Drive-by: pick up in-tree edits to switch_channel_device.hpp, executor.cc, communicator.hpp, and allreduce_rsag.cu. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -39,8 +39,7 @@ struct SwitchChannelDeviceHandle {
|
||||
|
||||
/// Vectorized multimem load+reduce. The optional `AccumT` template parameter selects the
|
||||
/// accumulator: when `AccumT == __half` and `VectorType` is an FP8 vector type, the
|
||||
/// `.acc::f16` variant of the instruction is used (faster but lower precision than the
|
||||
/// default FP32 accumulator). For all other types `AccumT` is ignored.
|
||||
/// `.acc::f16` variant of the instruction is used. For all other types `AccumT` is ignored.
|
||||
template <typename VectorType, typename AccumT = void>
|
||||
MSCCLPP_DEVICE_INLINE static VectorType multimemLoadReduce(VectorType* ptr) {
|
||||
VectorType val;
|
||||
|
||||
@@ -389,6 +389,7 @@ struct Executor::Impl {
|
||||
nvlsConnection->bindAllocatedMemory((CUdeviceptr)bufferInfo.first, bufferInfo.second);
|
||||
context.nvlsChannels.push_back(switchChannel);
|
||||
}
|
||||
this->comm->bootstrap()->barrier();
|
||||
}
|
||||
|
||||
void setupSemaphores(ExecutionContext& context, const ExecutionPlan& plan) {
|
||||
|
||||
@@ -60,6 +60,7 @@ struct Communicator::Impl {
|
||||
std::shared_ptr<Bootstrap> bootstrap_;
|
||||
std::shared_ptr<Context> context_;
|
||||
std::unordered_map<const BaseConnection*, ConnectionInfo> connectionInfos_;
|
||||
|
||||
// Temporary storage for the latest RecvItem of each {remoteRank, tag} pair.
|
||||
// The RecvItem is removed when it finishes or when getLastRecvItem observes that it is ready.
|
||||
std::unordered_map<std::pair<int, int>, std::shared_ptr<BaseRecvItem>, PairHash> lastRecvItems_;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <mscclpp/algorithm.hpp>
|
||||
|
||||
#include "allreduce/allreduce_nvls_block_pipeline.hpp"
|
||||
@@ -177,15 +176,15 @@ struct NvlsBlockPipelineAdapter {
|
||||
|
||||
void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr<Communicator> comm) {
|
||||
nSwitchChannels_ = 8;
|
||||
nRanksPerIpcDomain_ = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
int nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
// Per-peer channel allocation must hold up to 4 * nRanksPerIpcDomain entries (see kernel).
|
||||
nBaseChannels_ = std::max(64, 4 * nRanksPerIpcDomain_);
|
||||
int nBaseChannels = std::max(64, 4 * nRanksPerIpcDomain);
|
||||
this->conns_ = setupConnections(comm);
|
||||
// setup semaphores
|
||||
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores =
|
||||
setupMemorySemaphores(comm, this->conns_, nBaseChannels_);
|
||||
setupMemorySemaphores(comm, this->conns_, nBaseChannels);
|
||||
// setup base memory channels
|
||||
this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nBaseChannels_);
|
||||
this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nBaseChannels);
|
||||
this->memoryChannelsDeviceHandle_ = setupBaseMemoryChannelDeviceHandles(this->baseChannels_);
|
||||
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
|
||||
}
|
||||
@@ -228,7 +227,7 @@ std::shared_ptr<void> AllreduceNvlsBlockPipeline::initAllreduceContext(std::shar
|
||||
|
||||
// setup channels
|
||||
ctx->switchChannels =
|
||||
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_);
|
||||
setupNvlsChannels(comm, this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_);
|
||||
ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
|
||||
return ctx;
|
||||
}
|
||||
|
||||
@@ -82,8 +82,7 @@ void AllreduceNvlsPacket::initialize(std::shared_ptr<Communicator> comm) {
|
||||
int nSwitchChannels = 1;
|
||||
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels);
|
||||
this->switchChannels_ =
|
||||
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
|
||||
comm->bootstrap()->barrier();
|
||||
setupNvlsChannels(comm, this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
|
||||
}
|
||||
|
||||
AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <mscclpp/algorithm.hpp>
|
||||
|
||||
#include "allreduce/allreduce_nvls_warp_pipeline.hpp"
|
||||
@@ -141,15 +140,15 @@ struct NvlsWarpPipelineAdapter {
|
||||
|
||||
void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr<Communicator> comm) {
|
||||
nSwitchChannels_ = NUM_NVLS_CONNECTION;
|
||||
nRanksPerIpcDomain_ = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
int nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
// Per-peer channel allocation must hold 2 * nBlocks entries; default nBlocks = 4 * nRanksPerIpcDomain.
|
||||
nBaseChannels_ = std::max(64, 8 * nRanksPerIpcDomain_);
|
||||
int nBaseChannels = std::max(64, 8 * nRanksPerIpcDomain);
|
||||
this->conns_ = setupConnections(comm);
|
||||
// setup semaphores
|
||||
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores =
|
||||
setupMemorySemaphores(comm, this->conns_, nBaseChannels_);
|
||||
setupMemorySemaphores(comm, this->conns_, nBaseChannels);
|
||||
// setup base memory channels
|
||||
this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nBaseChannels_);
|
||||
this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nBaseChannels);
|
||||
this->memoryChannelsDeviceHandle_ = setupBaseMemoryChannelDeviceHandles(this->baseChannels_);
|
||||
this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_);
|
||||
}
|
||||
@@ -192,7 +191,7 @@ std::shared_ptr<void> AllreduceNvlsWarpPipeline::initAllreduceContext(std::share
|
||||
|
||||
// setup channels
|
||||
ctx->switchChannels =
|
||||
setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_);
|
||||
setupNvlsChannels(comm, this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_);
|
||||
ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
|
||||
return ctx;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/errors.hpp>
|
||||
|
||||
#include "allreduce/allreduce_nvls_zero_copy.hpp"
|
||||
#include "allreduce/common.hpp"
|
||||
@@ -195,11 +194,12 @@ std::shared_ptr<void> AllreduceNvls::initAllreduceContext(std::shared_ptr<mscclp
|
||||
MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)output));
|
||||
|
||||
// setup channels
|
||||
ctx->switchChannels = setupNvlsChannels(this->nvlsConnections_, (void*)sendBasePtr, sendBytes, nSwitchChannels_);
|
||||
ctx->switchChannels =
|
||||
setupNvlsChannels(comm, this->nvlsConnections_, (void*)sendBasePtr, sendBytes, nSwitchChannels_);
|
||||
if (input != output) {
|
||||
auto nvlsOutConnections = this->nvlsOutConnections_;
|
||||
std::vector<mscclpp::SwitchChannel> outChannels =
|
||||
setupNvlsChannels(this->nvlsOutConnections_, (void*)recvBasePtr, recvBytes, nSwitchChannels_);
|
||||
setupNvlsChannels(comm, this->nvlsOutConnections_, (void*)recvBasePtr, recvBytes, nSwitchChannels_);
|
||||
ctx->switchChannels.insert(ctx->switchChannels.end(), outChannels.begin(), outChannels.end());
|
||||
}
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ struct AllreduceRsAgAdapter {
|
||||
|
||||
void AllreduceRsAg::initialize(std::shared_ptr<Communicator> comm) {
|
||||
this->conns_ = setupConnections(comm);
|
||||
nChannelsPerConnection_ = 128;
|
||||
nChannelsPerConnection_ = 64;
|
||||
comm_ = comm;
|
||||
// setup semaphores
|
||||
this->scratchSemaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_);
|
||||
|
||||
@@ -6,12 +6,9 @@
|
||||
#include <algorithm>
|
||||
#include <mscclpp/algorithm.hpp>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/errors.hpp>
|
||||
#include <mscclpp/memory_channel.hpp>
|
||||
#include <mscclpp/switch_channel.hpp>
|
||||
|
||||
#include "logger.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace collective {
|
||||
std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_ptr<mscclpp::Communicator> comm, int rank,
|
||||
@@ -101,7 +98,8 @@ std::vector<std::shared_ptr<mscclpp::NvlsConnection>> setupNvlsConnections(std::
|
||||
return nvlsConnections;
|
||||
}
|
||||
|
||||
std::vector<mscclpp::SwitchChannel> setupNvlsChannels(std::vector<std::shared_ptr<mscclpp::NvlsConnection>> conns,
|
||||
std::vector<mscclpp::SwitchChannel> setupNvlsChannels(std::shared_ptr<mscclpp::Communicator> comm,
|
||||
std::vector<std::shared_ptr<mscclpp::NvlsConnection>> conns,
|
||||
void* buffer, size_t bufferSize, int nSwitchChannels) {
|
||||
std::vector<mscclpp::SwitchChannel> channels;
|
||||
|
||||
@@ -110,6 +108,8 @@ std::vector<mscclpp::SwitchChannel> setupNvlsChannels(std::vector<std::shared_pt
|
||||
mscclpp::SwitchChannel switchChannel = nvlsConnection->bindAllocatedMemory((CUdeviceptr)buffer, bufferSize);
|
||||
channels.push_back(switchChannel);
|
||||
}
|
||||
// Synchronize to make sure all ranks have their NVLS channels set up before any rank starts using them.
|
||||
comm->bootstrap()->barrier();
|
||||
return channels;
|
||||
}
|
||||
|
||||
|
||||
@@ -29,8 +29,6 @@ class AllreduceNvlsBlockPipeline : public AlgorithmBuilder {
|
||||
void* scratchBuffer_;
|
||||
size_t scratchBufferSize_;
|
||||
uint32_t nSwitchChannels_;
|
||||
int nRanksPerIpcDomain_ = 0;
|
||||
int nBaseChannels_ = 0;
|
||||
std::shared_ptr<DeviceHandle<BaseMemoryChannel>> memoryChannelsDeviceHandle_;
|
||||
std::vector<BaseMemoryChannel> baseChannels_;
|
||||
std::vector<Connection> conns_;
|
||||
|
||||
@@ -29,8 +29,6 @@ class AllreduceNvlsWarpPipeline : public AlgorithmBuilder {
|
||||
void* scratchBuffer_;
|
||||
size_t scratchBufferSize_;
|
||||
uint32_t nSwitchChannels_;
|
||||
int nRanksPerIpcDomain_ = 0;
|
||||
int nBaseChannels_ = 0;
|
||||
std::shared_ptr<DeviceHandle<BaseMemoryChannel>> memoryChannelsDeviceHandle_;
|
||||
std::vector<BaseMemoryChannel> baseChannels_;
|
||||
std::vector<Connection> conns_;
|
||||
|
||||
@@ -57,7 +57,8 @@ std::shared_ptr<DeviceHandle<MemoryChannel>> setupMemoryChannelDeviceHandles(
|
||||
std::vector<std::shared_ptr<NvlsConnection>> setupNvlsConnections(std::shared_ptr<Communicator> comm, size_t size,
|
||||
int numConnections);
|
||||
|
||||
std::vector<SwitchChannel> setupNvlsChannels(std::vector<std::shared_ptr<NvlsConnection>> conns, void* buffer,
|
||||
std::vector<SwitchChannel> setupNvlsChannels(std::shared_ptr<Communicator> comm,
|
||||
std::vector<std::shared_ptr<NvlsConnection>> conns, void* buffer,
|
||||
size_t bufferSize, int nSwitchChannels);
|
||||
|
||||
std::shared_ptr<DeviceHandle<SwitchChannel>> setupNvlsChannelDeviceHandles(
|
||||
|
||||
Reference in New Issue
Block a user