Add check to filter invalid nblock/nthread candidates (#811)

Add check for invalid nblock/nthread candidate
This commit is contained in:
Binyang Li
2026-05-22 09:18:41 -07:00
committed by GitHub
parent 9e177b388c
commit 08ee18be64
3 changed files with 46 additions and 14 deletions

View File

@@ -70,12 +70,12 @@ class CustomizedComm:
_TUNE_N_WARMUP = 5
_TUNE_N_GRAPH_LAUNCHES = 10
_TUNE_N_OPS_PER_GRAPH = 100
_CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 64, 128]
_CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 56, 64, 128]
_CANDIDATE_NTHREADS = [512, 768, 1024]
_NBLOCKS_LIMIT = {
"default_allreduce_nvls_packet": 16,
"default_allreduce_packet": 56,
"default_allreduce_allpair_packet": 56,
"default_allreduce_allpair_packet": 64,
"default_allreduce_fullmesh": 64,
"default_allgather_fullmesh2": 32,
}

View File

@@ -8,6 +8,11 @@
namespace mscclpp {
namespace collective {
namespace {

// Largest grid size accepted by the AllgatherFullmesh launch check. The same
// value is passed as the per-connection count when the memory semaphores and
// memory channels are set up, so the two limits stay in step.
constexpr int kMaxBlocks{56};

// Largest block size accepted by the AllgatherFullmesh launch check; equal to
// the first argument of the kernel's __launch_bounds__(1024, 1).
constexpr int kMaxThreadsPerBlock{1024};

}  // namespace
template <bool IsOutOfPlace>
__global__ void __launch_bounds__(1024, 1)
allgatherFullmesh(void* buff, void* scratch, void* resultBuff, DeviceHandle<MemoryChannel>* memoryChannels,
@@ -116,12 +121,19 @@ CommResult AllgatherFullmesh::allgatherKernelFunc(const std::shared_ptr<void> ct
int rank = ctx->rank;
const size_t nElem = inputSize / sizeof(int);
std::pair<int, int> numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
if (numBlocksAndThreads.first > 56) {
WARN("AllgatherFullmesh: number of blocks exceeds maximum supported blocks, which is 56");
return mscclpp::CommResult::CommInvalidArgument;
}
if (numBlocksAndThreads.first == 0 || numBlocksAndThreads.second == 0) {
numBlocksAndThreads = {56, 1024};
numBlocksAndThreads = {kMaxBlocks, kMaxThreadsPerBlock};
}
if (numBlocksAndThreads.first > kMaxBlocks || numBlocksAndThreads.second > kMaxThreadsPerBlock) {
WARN(
"AllgatherFullmesh: number of blocks must be no more than %d and threads per block must be no more than %d; "
"got nBlocks=%d, nThreadsPerBlock=%d",
kMaxBlocks, kMaxThreadsPerBlock, numBlocksAndThreads.first, numBlocksAndThreads.second);
return CommResult::CommInvalidArgument;
}
if (numBlocksAndThreads.second % WARP_SIZE != 0) {
WARN("AllgatherFullmesh: threads per block must be a multiple of warp size %d", WARP_SIZE);
return CommResult::CommInvalidArgument;
}
if ((char*)input == (char*)output + rank * inputSize) {
allgatherFullmesh<false><<<numBlocksAndThreads.first, numBlocksAndThreads.second, 0, stream>>>(
@@ -142,15 +154,13 @@ CommResult AllgatherFullmesh::allgatherKernelFunc(const std::shared_ptr<void> ct
std::shared_ptr<void> AllgatherFullmesh::initAllgatherContext(std::shared_ptr<Communicator> comm, const void* input,
void*, size_t inputSize, DataType) {
constexpr int nChannelsPerConnection = 56;
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
// setup semaphores
ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection);
ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, kMaxBlocks);
// register the memory for the broadcast operation
RegisteredMemory localMemory = comm->registerMemory((void*)input, inputSize, Transport::CudaIpc);
@@ -159,7 +169,7 @@ std::shared_ptr<void> AllgatherFullmesh::initAllgatherContext(std::shared_ptr<Co
// setup channels
ctx->memoryChannels =
setupMemoryChannels(this->conns_, ctx->memorySemaphores, remoteMemories, localMemory, nChannelsPerConnection);
setupMemoryChannels(this->conns_, ctx->memorySemaphores, remoteMemories, localMemory, kMaxBlocks);
ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);
// keep registered memories reference
@@ -196,4 +206,4 @@ std::shared_ptr<Algorithm> AllgatherFullmesh::build() {
});
}
} // namespace collective
} // namespace mscclpp
} // namespace mscclpp

View File

@@ -18,7 +18,11 @@ __global__ void __launch_bounds__(1024, 1)
const size_t lid = tid % WARP_SIZE;
const size_t wid = tid / WARP_SIZE;
const size_t nThread = blockDim.x * gridDim.x;
// Round down to multiple of warp size
const size_t nThread = (blockDim.x * gridDim.x) / WARP_SIZE * WARP_SIZE;
if (tid >= nThread) {
return;
}
const size_t nWarp = nThread / WARP_SIZE;
const size_t nPeer = nRanksPerNode - 1;
const size_t chanOffset = nPeer * blockIdx.x;
@@ -135,6 +139,24 @@ CommResult AllgatherFullmesh2::allgatherKernelFunc(const std::shared_ptr<void> c
numBlocksAndThreads.first = 35;
}
}
const int nPeer = ctx->nRanksPerNode - 1;
const int nWarp = numBlocksAndThreads.first * numBlocksAndThreads.second / WARP_SIZE;
if (numBlocksAndThreads.first > nChannelsPerConnection_ || numBlocksAndThreads.first <= 0 ||
numBlocksAndThreads.second <= 0) {
WARN(
"AllgatherFullmesh2: number of blocks must be a positive multiple of peer count and no more than %d, threads "
"per block must be positive; got nBlocks=%d, nThreadsPerBlock=%d, nPeers=%d",
nChannelsPerConnection_, numBlocksAndThreads.first, numBlocksAndThreads.second, nPeer);
return CommResult::CommInvalidArgument;
}
if (nWarp < nPeer) {
WARN(
"AllgatherFullmesh2: total number of warps must be no less than peer count; got nBlocks=%d, "
"nThreadsPerBlock=%d, "
"nPeers=%d",
numBlocksAndThreads.first, numBlocksAndThreads.second, nPeer);
return CommResult::CommInvalidArgument;
}
size_t channelOutOffset = *static_cast<size_t*>(ctx->extras["channel_out_offset"].get());
if ((char*)input == (char*)output + rank * inputSize) {
@@ -226,4 +248,4 @@ std::shared_ptr<Algorithm> AllgatherFullmesh2::build() {
}
} // namespace collective
} // namespace mscclpp
} // namespace mscclpp