Shorten verbose comments and use THROW in validateIpcDomainSpansWorld

- Collapse the duplicated 3-line warp-strided-load comment in 5 kernels
  (allgather_fullmesh, allreduce_fullmesh, allreduce_packet,
  allreduce_nvls_zero_copy, allreduce_nvls_warp_pipeline) into a single
  one-line 'Peer count may exceed WARP_SIZE on MNNVL.' note.
- Drop the algName parameter from validateIpcDomainSpansWorld; switch
  its 3 throws to use the THROW logger macro (LogSubsys::ALGO), which
  already captures file/line/function. Update the 3 callsites
  (nvls_block_pipeline, nvls_warp_pipeline, nvls_zero_copy) and trim the
  Doxygen comment accordingly.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Binyang Li
2026-05-06 21:37:09 +00:00
parent 4a0d5b29d5
commit 307a471888
8 changed files with 21 additions and 35 deletions

View File

@@ -30,9 +30,7 @@ __global__ void __launch_bounds__(1024, 1)
__shared__ DeviceHandle<MemoryChannel> channels[MAX_IPC_DOMAIN_NRANKS - 1];
const int lid = threadIdx.x % WARP_SIZE;
// Each warp redundantly loads all entries (same value, benign race) so that
// every warp has the data its threads will read after __syncwarp(). Required
// when nPeer > WARP_SIZE (MNNVL/NVL72 scale).
// Peer count may exceed WARP_SIZE on MNNVL.
for (int i = lid; i < nPeer; i += WARP_SIZE) {
channels[i] = memoryChans[i];
}

View File

@@ -52,9 +52,7 @@ __global__ void __launch_bounds__(512, 1)
__shared__ DeviceHandle<MemoryChannel> channels[MAX_IPC_DOMAIN_NRANKS - 1];
__shared__ DeviceHandle<MemoryChannel> outChannels[MAX_IPC_DOMAIN_NRANKS - 1];
const int lid = threadIdx.x % WARP_SIZE;
// Each warp redundantly loads all entries (same value, benign race) so that
// every warp has the data its threads will read after __syncwarp(). Required
// when nPeer > WARP_SIZE (MNNVL/NVL72 scale).
// Peer count may exceed WARP_SIZE on MNNVL.
for (int i = lid; i < nPeer; i += WARP_SIZE) {
channels[i] = memoryChans[i];
outChannels[i] = memoryOutChans[i];

View File

@@ -178,7 +178,7 @@ struct NvlsBlockPipelineAdapter {
void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr<Communicator> comm) {
nSwitchChannels_ = 8;
ipcDomainNranks_ = validateIpcDomainSpansWorld(comm, "AllreduceNvlsBlockPipeline");
ipcDomainNranks_ = validateIpcDomainSpansWorld(comm);
// Block-pipeline device-side semaphore indices grow as 6 * ipcDomainNranks (see kernel).
if (6 * ipcDomainNranks_ > NUM_SEMAPHORES) {
throw Error("AllreduceNvlsBlockPipeline: ipcDomainNranks " + std::to_string(ipcDomainNranks_) +

View File

@@ -59,9 +59,7 @@ __global__ void __launch_bounds__(1024, 1)
auto memoryChans = memoryChannels + chanOffset;
__shared__ DeviceHandle<BaseMemoryChannel> channels[(MAX_IPC_DOMAIN_NRANKS - 1) * 2];
const int lid = threadIdx.x % WARP_SIZE;
// Each warp redundantly loads all entries (same value, benign race) so that
// every warp has the data its threads will read after __syncwarp(). Required
// when nPeers*2 > WARP_SIZE (MNNVL scale).
// Peer count may exceed WARP_SIZE on MNNVL.
for (int i = lid; i < nPeers * 2; i += WARP_SIZE) {
channels[i] = memoryChans[i];
}
@@ -144,7 +142,7 @@ struct NvlsWarpPipelineAdapter {
void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr<Communicator> comm) {
nSwitchChannels_ = NUM_NVLS_CONNECTION;
ipcDomainNranks_ = validateIpcDomainSpansWorld(comm, "AllreduceNvlsWarpPipeline");
ipcDomainNranks_ = validateIpcDomainSpansWorld(comm);
// The warp-pipeline kernel addresses 2 * nPeers entries per block in `memoryChannels`,
// so per-peer base channel allocation must be at least `2 * nBlocks`. Default
// nBlocks = 4 * ipcDomainNranks (see allreduceKernelFunc), so size accordingly.

View File

@@ -45,9 +45,7 @@ __global__ void __launch_bounds__(1024, 1)
auto memoryChans = memoryChannels + chanOffset;
__shared__ mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel> channels[MAX_IPC_DOMAIN_NRANKS - 1];
const int lid = threadIdx.x % WARP_SIZE;
// Each warp redundantly loads all entries (same value, benign race) so that
// every warp has the data its threads will read after __syncwarp(). Required
// when nPeers > WARP_SIZE (MNNVL/NVL72 → 71 peers).
// Peer count may exceed WARP_SIZE on MNNVL.
for (int i = lid; i < ipcDomainNranks - 1; i += WARP_SIZE) {
channels[i] = memoryChans[i];
}
@@ -107,7 +105,7 @@ void AllreduceNvls::initialize(std::shared_ptr<mscclpp::Communicator> comm) {
MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&deviceProp, device));
computeCapabilityMajor_ = deviceProp.major;
nSwitchChannels_ = 32;
validateIpcDomainSpansWorld(comm, "AllreduceNvls");
validateIpcDomainSpansWorld(comm);
this->conns_ = setupConnections(comm);
// setup semaphores
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores =

View File

@@ -80,9 +80,7 @@ __global__ void __launch_bounds__(1024, 1)
// Put channels into shared memory, read channel info from global memory is unexpectable slow.
__shared__ mscclpp::DeviceHandle<mscclpp::MemoryChannel> channels[MAX_IPC_DOMAIN_NRANKS - 1];
const int lid = tid % WARP_SIZE;
// Each warp redundantly loads all entries (same value, benign race) so that
// every warp has the data its threads will read after __syncwarp(). Required
// when nPeers > WARP_SIZE (MNNVL/NVL72 scale).
// Peer count may exceed WARP_SIZE on MNNVL.
for (int i = lid; i < nPeers; i += WARP_SIZE) {
channels[i] = memoryChannels[i];
}

View File

@@ -11,6 +11,8 @@
#include <mscclpp/memory_channel.hpp>
#include <mscclpp/switch_channel.hpp>
#include "logger.hpp"
namespace mscclpp {
namespace collective {
std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_ptr<mscclpp::Communicator> comm, int rank,
@@ -79,24 +81,22 @@ int getIpcDomainNranks(std::shared_ptr<mscclpp::Communicator> comm) {
return comm->bootstrap()->getNranksPerNode();
}
int validateIpcDomainSpansWorld(std::shared_ptr<mscclpp::Communicator> comm, const char* algName) {
int validateIpcDomainSpansWorld(std::shared_ptr<mscclpp::Communicator> comm) {
const int ipcDomainNranks = getIpcDomainNranks(comm);
const int worldSize = comm->bootstrap()->getNranks();
const int rank = comm->bootstrap()->getRank();
if (ipcDomainNranks < 2 || ipcDomainNranks > MAX_IPC_DOMAIN_NRANKS) {
throw mscclpp::Error(std::string(algName) + ": ipcDomainNranks " + std::to_string(ipcDomainNranks) +
" is out of supported range [2, " + std::to_string(MAX_IPC_DOMAIN_NRANKS) + "]",
mscclpp::ErrorCode::InvalidUsage);
THROW(mscclpp::LogSubsys::ALGO, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage, "ipcDomainNranks ",
ipcDomainNranks, " is out of supported range [2, ", MAX_IPC_DOMAIN_NRANKS, "]");
}
if (worldSize != ipcDomainNranks) {
throw mscclpp::Error(std::string(algName) + " requires worldSize == ipcDomainNranks (got worldSize=" +
std::to_string(worldSize) + ", ipcDomainNranks=" + std::to_string(ipcDomainNranks) + ")",
mscclpp::ErrorCode::InvalidUsage);
THROW(mscclpp::LogSubsys::ALGO, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
"requires worldSize == ipcDomainNranks (got worldSize=", worldSize, ", ipcDomainNranks=", ipcDomainNranks,
")");
}
if (rank < 0 || rank >= ipcDomainNranks) {
throw mscclpp::Error(std::string(algName) + ": rank " + std::to_string(rank) + " out of [0, " +
std::to_string(ipcDomainNranks) + ")",
mscclpp::ErrorCode::InvalidUsage);
THROW(mscclpp::LogSubsys::ALGO, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage, "rank ", rank, " out of [0, ",
ipcDomainNranks, ")");
}
return ipcDomainNranks;
}

View File

@@ -60,13 +60,9 @@ int getIpcDomainNranks(std::shared_ptr<Communicator> comm);
/// Validates that the IPC domain spans the whole communicator and that the local rank fits within
/// the supported `[2, MAX_IPC_DOMAIN_NRANKS]` range. Used by NVLS allreduce algorithms whose
/// multicast group spans the whole communicator (see `setupNvlsConnections`) and whose kernels
/// use the global rank to compute per-rank offsets while sizing per-rank work by
/// `ipcDomainNranks`. These assumptions only hold when the IPC-reachable peer group is exactly
/// the whole communicator (e.g. a fully populated MNNVL fabric like NVL72). Returns the validated
/// `ipcDomainNranks`; throws `Error(InvalidUsage)` on violation. `algName` is used as a prefix
/// in error messages.
int validateIpcDomainSpansWorld(std::shared_ptr<Communicator> comm, const char* algName);
/// multicast group spans the whole communicator. Returns the validated `ipcDomainNranks`; throws
/// `Error(InvalidUsage)` on violation.
int validateIpcDomainSpansWorld(std::shared_ptr<Communicator> comm);
std::shared_ptr<DeviceHandle<MemoryChannel>> setupMemoryChannelDeviceHandles(
const std::vector<MemoryChannel>& memoryChannels);