diff --git a/include/mscclpp/switch_channel_device.hpp b/include/mscclpp/switch_channel_device.hpp index df22bd3a..fcdd7fdd 100644 --- a/include/mscclpp/switch_channel_device.hpp +++ b/include/mscclpp/switch_channel_device.hpp @@ -39,8 +39,7 @@ struct SwitchChannelDeviceHandle { /// Vectorized multimem load+reduce. The optional `AccumT` template parameter selects the /// accumulator: when `AccumT == __half` and `VectorType` is an FP8 vector type, the - /// `.acc::f16` variant of the instruction is used (faster but lower precision than the - /// default FP32 accumulator). For all other types `AccumT` is ignored. + /// `.acc::f16` variant of the instruction is used. For all other types `AccumT` is ignored. template MSCCLPP_DEVICE_INLINE static VectorType multimemLoadReduce(VectorType* ptr) { VectorType val; diff --git a/src/core/executor/executor.cc b/src/core/executor/executor.cc index fcecc4dd..15c6af4e 100644 --- a/src/core/executor/executor.cc +++ b/src/core/executor/executor.cc @@ -389,6 +389,7 @@ struct Executor::Impl { nvlsConnection->bindAllocatedMemory((CUdeviceptr)bufferInfo.first, bufferInfo.second); context.nvlsChannels.push_back(switchChannel); } + this->comm->bootstrap()->barrier(); } void setupSemaphores(ExecutionContext& context, const ExecutionPlan& plan) { diff --git a/src/core/include/communicator.hpp b/src/core/include/communicator.hpp index 333cc982..f15e20f7 100644 --- a/src/core/include/communicator.hpp +++ b/src/core/include/communicator.hpp @@ -60,6 +60,7 @@ struct Communicator::Impl { std::shared_ptr bootstrap_; std::shared_ptr context_; std::unordered_map connectionInfos_; + // Temporary storage for the latest RecvItem of each {remoteRank, tag} pair. // The RecvItem is removed when it finishes or when getLastRecvItem observes that it is ready. std::unordered_map, std::shared_ptr, PairHash> lastRecvItems_; diff --git a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu index 347ce8b4..04c7f8c9 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -#include #include #include "allreduce/allreduce_nvls_block_pipeline.hpp" @@ -177,15 +176,15 @@ struct NvlsBlockPipelineAdapter { void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr comm) { nSwitchChannels_ = 8; - nRanksPerIpcDomain_ = comm->bootstrap()->getNranksPerIpcDomain(); + int nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); // Per-peer channel allocation must hold up to 4 * nRanksPerIpcDomain entries (see kernel). - nBaseChannels_ = std::max(64, 4 * nRanksPerIpcDomain_); + int nBaseChannels = std::max(64, 4 * nRanksPerIpcDomain); this->conns_ = setupConnections(comm); // setup semaphores std::vector> memorySemaphores = - setupMemorySemaphores(comm, this->conns_, nBaseChannels_); + setupMemorySemaphores(comm, this->conns_, nBaseChannels); // setup base memory channels - this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nBaseChannels_); + this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nBaseChannels); this->memoryChannelsDeviceHandle_ = setupBaseMemoryChannelDeviceHandles(this->baseChannels_); this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_); } @@ -228,7 +227,7 @@ std::shared_ptr AllreduceNvlsBlockPipeline::initAllreduceContext(std::shar // setup channels ctx->switchChannels = - setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_); + setupNvlsChannels(comm, this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_); ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels); return ctx; } diff --git a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu index f16e8b05..1918eef1 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu @@ -82,8 +82,7 @@ void AllreduceNvlsPacket::initialize(std::shared_ptr comm) { int nSwitchChannels = 1; this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels); this->switchChannels_ = - setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels); - comm->bootstrap()->barrier(); + setupNvlsChannels(comm, this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels); } AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) { diff --git a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu index ba447d32..d5bbb2e7 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -#include #include #include "allreduce/allreduce_nvls_warp_pipeline.hpp" @@ -141,15 +140,15 @@ struct NvlsWarpPipelineAdapter { void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr comm) { nSwitchChannels_ = NUM_NVLS_CONNECTION; - nRanksPerIpcDomain_ = comm->bootstrap()->getNranksPerIpcDomain(); + int nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); // Per-peer channel allocation must hold 2 * nBlocks entries; default nBlocks = 4 * nRanksPerIpcDomain. - nBaseChannels_ = std::max(64, 8 * nRanksPerIpcDomain_); + int nBaseChannels = std::max(64, 8 * nRanksPerIpcDomain); this->conns_ = setupConnections(comm); // setup semaphores std::vector> memorySemaphores = - setupMemorySemaphores(comm, this->conns_, nBaseChannels_); + setupMemorySemaphores(comm, this->conns_, nBaseChannels); // setup base memory channels - this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nBaseChannels_); + this->baseChannels_ = setupBaseMemoryChannels(this->conns_, memorySemaphores, nBaseChannels); this->memoryChannelsDeviceHandle_ = setupBaseMemoryChannelDeviceHandles(this->baseChannels_); this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_); } @@ -192,7 +191,7 @@ std::shared_ptr AllreduceNvlsWarpPipeline::initAllreduceContext(std::share // setup channels ctx->switchChannels = - setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_); + setupNvlsChannels(comm, this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_); ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels); return ctx; } diff --git a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu index 32fc6142..481e8ad8 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu @@ -2,7 +2,6 @@ // Licensed under the MIT License. #include -#include #include "allreduce/allreduce_nvls_zero_copy.hpp" #include "allreduce/common.hpp" @@ -195,11 +194,12 @@ std::shared_ptr AllreduceNvls::initAllreduceContext(std::shared_ptrswitchChannels = setupNvlsChannels(this->nvlsConnections_, (void*)sendBasePtr, sendBytes, nSwitchChannels_); + ctx->switchChannels = + setupNvlsChannels(comm, this->nvlsConnections_, (void*)sendBasePtr, sendBytes, nSwitchChannels_); if (input != output) { auto nvlsOutConnections = this->nvlsOutConnections_; std::vector outChannels = - setupNvlsChannels(this->nvlsOutConnections_, (void*)recvBasePtr, recvBytes, nSwitchChannels_); + setupNvlsChannels(comm, this->nvlsOutConnections_, (void*)recvBasePtr, recvBytes, nSwitchChannels_); ctx->switchChannels.insert(ctx->switchChannels.end(), outChannels.begin(), outChannels.end()); } diff --git a/src/ext/collectives/allreduce/allreduce_rsag.cu b/src/ext/collectives/allreduce/allreduce_rsag.cu index f07e0e2c..6fffc4da 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag.cu @@ -144,7 +144,7 @@ struct AllreduceRsAgAdapter { void AllreduceRsAg::initialize(std::shared_ptr comm) { this->conns_ = setupConnections(comm); - nChannelsPerConnection_ = 128; + nChannelsPerConnection_ = 64; comm_ = comm; // setup semaphores this->scratchSemaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_); diff --git a/src/ext/collectives/collective_utils.cc b/src/ext/collectives/collective_utils.cc index c3856a88..5d038afa 100644 --- a/src/ext/collectives/collective_utils.cc +++ b/src/ext/collectives/collective_utils.cc @@ -6,12 +6,9 @@ #include #include #include -#include #include #include -#include "logger.hpp" - namespace mscclpp { namespace collective { std::vector setupRemoteMemories(std::shared_ptr comm, int rank, @@ -101,7 +98,8 @@ std::vector> setupNvlsConnections(std:: return nvlsConnections; } -std::vector setupNvlsChannels(std::vector> conns, +std::vector setupNvlsChannels(std::shared_ptr comm, + std::vector> conns, void* buffer, size_t bufferSize, int nSwitchChannels) { std::vector channels; @@ -110,6 +108,8 @@ std::vector setupNvlsChannels(std::vectorbindAllocatedMemory((CUdeviceptr)buffer, bufferSize); channels.push_back(switchChannel); } + // Synchronize to make sure all ranks have their NVLS channels set up before any rank starts using them. + comm->bootstrap()->barrier(); return channels; } diff --git a/src/ext/collectives/include/allreduce/allreduce_nvls_block_pipeline.hpp b/src/ext/collectives/include/allreduce/allreduce_nvls_block_pipeline.hpp index 5662d116..81b74add 100644 --- a/src/ext/collectives/include/allreduce/allreduce_nvls_block_pipeline.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_nvls_block_pipeline.hpp @@ -29,8 +29,6 @@ class AllreduceNvlsBlockPipeline : public AlgorithmBuilder { void* scratchBuffer_; size_t scratchBufferSize_; uint32_t nSwitchChannels_; - int nRanksPerIpcDomain_ = 0; - int nBaseChannels_ = 0; std::shared_ptr> memoryChannelsDeviceHandle_; std::vector baseChannels_; std::vector conns_; diff --git a/src/ext/collectives/include/allreduce/allreduce_nvls_warp_pipeline.hpp b/src/ext/collectives/include/allreduce/allreduce_nvls_warp_pipeline.hpp index f347c871..8f02a873 100644 --- a/src/ext/collectives/include/allreduce/allreduce_nvls_warp_pipeline.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_nvls_warp_pipeline.hpp @@ -29,8 +29,6 @@ class AllreduceNvlsWarpPipeline : public AlgorithmBuilder { void* scratchBuffer_; size_t scratchBufferSize_; uint32_t nSwitchChannels_; - int nRanksPerIpcDomain_ = 0; - int nBaseChannels_ = 0; std::shared_ptr> memoryChannelsDeviceHandle_; std::vector baseChannels_; std::vector conns_; diff --git a/src/ext/collectives/include/collective_utils.hpp b/src/ext/collectives/include/collective_utils.hpp index 2e61b937..95ce7f5a 100644 --- a/src/ext/collectives/include/collective_utils.hpp +++ b/src/ext/collectives/include/collective_utils.hpp @@ -57,7 +57,8 @@ std::shared_ptr> setupMemoryChannelDeviceHandles( std::vector> setupNvlsConnections(std::shared_ptr comm, size_t size, int numConnections); -std::vector setupNvlsChannels(std::vector> conns, void* buffer, +std::vector setupNvlsChannels(std::shared_ptr comm, + std::vector> conns, void* buffer, size_t bufferSize, int nSwitchChannels); std::shared_ptr> setupNvlsChannelDeviceHandles(