Strip preflight validation blocks from NVLS pipeline allreduce kernels

allreduce_nvls_block_pipeline.cu and allreduce_nvls_warp_pipeline.cu
were carrying ~45 lines of per-call invariant-checking added during the
MNNVL work. Restore main's simple defaulting pattern (just set
defaults when nBlocks/nThreadsPerBlock are passed as 0); incorrect
inputs will manifest as CUDA errors via
the existing error-handling path. Also drop the unreachable
`6 * ipcDomainNranks > NUM_SEMAPHORES` throw in the block_pipeline
initialize (max ipcDomainNranks=72, NUM_SEMAPHORES=512), the now-unused
`<mscclpp/errors.hpp>` include, and trim the verbose comments around
`nBaseChannels_` sizing in both files.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Binyang Li
2026-05-06 23:04:41 +00:00
parent 639b80de7b
commit e8caab7c8e
2 changed files with 6 additions and 101 deletions

View File

@@ -3,7 +3,6 @@
#include <algorithm>
#include <mscclpp/algorithm.hpp>
#include <mscclpp/errors.hpp>
#include "allreduce/allreduce_nvls_block_pipeline.hpp"
#include "allreduce/common.hpp"
@@ -179,14 +178,7 @@ struct NvlsBlockPipelineAdapter {
void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr<Communicator> comm) {
nSwitchChannels_ = 8;
ipcDomainNranks_ = getIpcDomainNranks(comm);
// Block-pipeline device-side semaphore indices grow as 6 * ipcDomainNranks (see kernel).
if (6 * ipcDomainNranks_ > NUM_SEMAPHORES) {
throw Error("AllreduceNvlsBlockPipeline: ipcDomainNranks " + std::to_string(ipcDomainNranks_) +
" exceeds NUM_SEMAPHORES capacity (" + std::to_string(NUM_SEMAPHORES) + ")",
ErrorCode::InvalidUsage);
}
// The kernel addresses up to `2 * nBlocksForCopy = 4 * ipcDomainNranks` distinct entries
// per peer in `memoryChannels`. Scale the per-connection allocation to match.
// Per-peer channel allocation must hold up to 4 * ipcDomainNranks entries (see kernel).
nBaseChannels_ = std::max(64, 4 * ipcDomainNranks_);
this->conns_ = setupConnections(comm);
// setup semaphores
@@ -208,43 +200,9 @@ CommResult AllreduceNvlsBlockPipeline::allreduceKernelFunc(
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
return CommResult::CommInvalidArgument;
}
const int requiredBlocks = ctx->ipcDomainNranks * 5;
std::pair<int, int> blockAndThreadNum = {nBlocks, nThreadsPerBlock};
if (blockAndThreadNum.first == 0) blockAndThreadNum.first = requiredBlocks;
if (blockAndThreadNum.second == 0) blockAndThreadNum.second = 1024;
if (blockAndThreadNum.first != requiredBlocks) {
WARN("AllreduceNvlsBlockPipeline requires nBlocks == 5 * ipcDomainNranks (got %d, expected %d)",
blockAndThreadNum.first, requiredBlocks);
return CommResult::CommInvalidArgument;
}
if (blockAndThreadNum.second != 1024) {
WARN("AllreduceNvlsBlockPipeline requires nThreadsPerBlock == 1024 (got %d)", blockAndThreadNum.second);
return CommResult::CommInvalidArgument;
}
// Validate input alignment/divisibility expectations of the kernel.
constexpr size_t kKernelAlignment = 16;
const size_t perRankBytes = inputSize / ctx->ipcDomainNranks;
if (perRankBytes * static_cast<size_t>(ctx->ipcDomainNranks) != inputSize || perRankBytes % kKernelAlignment != 0) {
WARN(
"AllreduceNvlsBlockPipeline requires inputSize %% (ipcDomainNranks * %zu) == 0 (got inputSize=%zu, "
"ipcDomainNranks=%d)",
kKernelAlignment, inputSize, ctx->ipcDomainNranks);
return CommResult::CommInvalidArgument;
}
// Validate scratch is large enough for at least one pipeline iteration. The kernel
// computes scratchSizePerBlock = (scratchSizePerRank / nBlocksForCopy) aligned down
// to unitSize; if this is 0, maxItersForScratch is 0 and the kernel deadlocks.
const size_t unitSize = (inputSize <= static_cast<size_t>(1024) * 1024 * 128) ? (1ULL << 16) : (1ULL << 17);
const size_t scratchSizePerRank = this->scratchBufferSize_ / ctx->ipcDomainNranks;
const size_t nBlocksForCopy = static_cast<size_t>(ctx->ipcDomainNranks) * 2;
const size_t scratchSizePerBlock = (scratchSizePerRank / nBlocksForCopy) / unitSize * unitSize;
if (scratchSizePerBlock < unitSize) {
WARN(
"AllreduceNvlsBlockPipeline scratch buffer too small for ipcDomainNranks=%d and inputSize=%zu "
"(scratchBufferSize=%zu, need at least ~%zu bytes)",
ctx->ipcDomainNranks, inputSize, this->scratchBufferSize_,
static_cast<size_t>(ctx->ipcDomainNranks) * nBlocksForCopy * unitSize);
return CommResult::CommInvalidArgument;
if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
blockAndThreadNum = {ctx->ipcDomainNranks * 5, 1024};
}
cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->memoryChannelsDeviceHandle_.get(), nullptr,
ctx->switchChannelDeviceHandles.get(), nullptr, 0, 0, this->scratchBufferSize_,

View File

@@ -3,7 +3,6 @@
#include <algorithm>
#include <mscclpp/algorithm.hpp>
#include <mscclpp/errors.hpp>
#include "allreduce/allreduce_nvls_warp_pipeline.hpp"
#include "allreduce/common.hpp"
@@ -143,9 +142,7 @@ struct NvlsWarpPipelineAdapter {
void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr<Communicator> comm) {
nSwitchChannels_ = NUM_NVLS_CONNECTION;
ipcDomainNranks_ = getIpcDomainNranks(comm);
// The warp-pipeline kernel addresses 2 * nPeers entries per block in `memoryChannels`,
// so per-peer base channel allocation must be at least `2 * nBlocks`. Default
// nBlocks = 4 * ipcDomainNranks (see allreduceKernelFunc), so size accordingly.
// Per-peer channel allocation must hold 2 * nBlocks entries; default nBlocks = 4 * ipcDomainNranks.
nBaseChannels_ = std::max(64, 8 * ipcDomainNranks_);
this->conns_ = setupConnections(comm);
// setup semaphores
@@ -168,58 +165,8 @@ CommResult AllreduceNvlsWarpPipeline::allreduceKernelFunc(
return CommResult::CommInvalidArgument;
}
std::pair<int, int> blockAndThreadNum = {nBlocks, nThreadsPerBlock};
if (blockAndThreadNum.first == 0) {
// Default to 4 * ipcDomainNranks blocks, rounded up to a multiple of NUM_NVLS_CONNECTION
// so that nBlocks / NUM_NVLS_CONNECTION partitioning in the kernel is well-defined.
int defaultBlocks = ctx->ipcDomainNranks * 4;
defaultBlocks = ((defaultBlocks + NUM_NVLS_CONNECTION - 1) / NUM_NVLS_CONNECTION) * NUM_NVLS_CONNECTION;
blockAndThreadNum.first = std::max(defaultBlocks, NUM_NVLS_CONNECTION);
}
if (blockAndThreadNum.second == 0) blockAndThreadNum.second = 1024;
// The kernel computes nBlocksPerNvlsConn = nBlocks / NUM_NVLS_CONNECTION and indexes the
// multicast handle array with bid / nBlocksPerNvlsConn; both must be safe.
if (blockAndThreadNum.first < NUM_NVLS_CONNECTION || blockAndThreadNum.first % NUM_NVLS_CONNECTION != 0) {
WARN("AllreduceNvlsWarpPipeline requires nBlocks to be a positive multiple of %d (got %d)", NUM_NVLS_CONNECTION,
blockAndThreadNum.first);
return CommResult::CommInvalidArgument;
}
// Each block uses 2 * nPeers consecutive entries in `memoryChannels`, so the per-peer
// base-channel allocation must support 2 * nBlocks distinct entries.
if (2 * blockAndThreadNum.first > this->nBaseChannels_) {
WARN(
"AllreduceNvlsWarpPipeline: nBlocks %d exceeds channel allocation (nBaseChannels=%d, "
"ipcDomainNranks=%d). Increase MSCCLPP_IPC_DOMAIN_NRANKS-aware sizing or reduce nBlocks.",
blockAndThreadNum.first, this->nBaseChannels_, ctx->ipcDomainNranks);
return CommResult::CommInvalidArgument;
}
// The kernel hard-codes 14 + 4 + 14 = 32 warps per block and bar.sync member counts
// computed from these constants; deviating from 1024 threads breaks those barriers.
if (blockAndThreadNum.second != 1024) {
WARN("AllreduceNvlsWarpPipeline requires nThreadsPerBlock == 1024 (got %d)", blockAndThreadNum.second);
return CommResult::CommInvalidArgument;
}
// Validate input divisibility by ipcDomainNranks (kernel computes size / ipcDomainNranks).
if (inputSize % static_cast<size_t>(ctx->ipcDomainNranks) != 0) {
WARN("AllreduceNvlsWarpPipeline requires inputSize %% ipcDomainNranks == 0 (got inputSize=%zu, ipcDomainNranks=%d)",
inputSize, ctx->ipcDomainNranks);
return CommResult::CommInvalidArgument;
}
// Validate scratch is large enough for at least one pipeline iteration. The kernel
// computes scratchSizePerBlock = (scratchSizePerRank / nBlocks) aligned down to copyPerIter;
// if this is 0 the modulo offset arithmetic divides by zero.
const size_t sizePerRank = inputSize / static_cast<size_t>(ctx->ipcDomainNranks);
const size_t maxSizePerBlock = ((sizePerRank + blockAndThreadNum.first - 1) / blockAndThreadNum.first + 15) / 16 * 16;
const size_t copyPerIter = (maxSizePerBlock >= 1024 * 64) ? (1024 * 32) : (1024 * 16);
const size_t scratchSizePerRank = this->scratchBufferSize_ / static_cast<size_t>(ctx->ipcDomainNranks);
const size_t scratchSizePerBlock =
(scratchSizePerRank / static_cast<size_t>(blockAndThreadNum.first)) / copyPerIter * copyPerIter;
if (scratchSizePerBlock < copyPerIter) {
WARN(
"AllreduceNvlsWarpPipeline scratch buffer too small for ipcDomainNranks=%d, nBlocks=%d, inputSize=%zu "
"(scratchBufferSize=%zu, need at least ~%zu bytes)",
ctx->ipcDomainNranks, blockAndThreadNum.first, inputSize, this->scratchBufferSize_,
static_cast<size_t>(ctx->ipcDomainNranks) * static_cast<size_t>(blockAndThreadNum.first) * copyPerIter);
return CommResult::CommInvalidArgument;
if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
blockAndThreadNum = {ctx->ipcDomainNranks * 4, 1024};
}
cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->memoryChannelsDeviceHandle_.get(), nullptr,
ctx->switchChannelDeviceHandles.get(), nullptr, 0, 0, this->scratchBufferSize_,