mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
Shorten verbose comments and use THROW in validateIpcDomainSpansWorld
- Collapse the duplicated 3-line warp-strided-load comment in 5 kernels (allgather_fullmesh, allreduce_fullmesh, allreduce_packet, allreduce_nvls_zero_copy, allreduce_nvls_warp_pipeline) into a single one-line 'Peer count may exceed WARP_SIZE on MNNVL.' note.
- Drop the algName parameter from validateIpcDomainSpansWorld; switch its 3 throws to use the THROW logger macro (LogSubsys::ALGO), which already captures file/line/function. Update the 3 callsites (nvls_block_pipeline, nvls_warp_pipeline, nvls_zero_copy) and trim the Doxygen comment accordingly.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -30,9 +30,7 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
|
||||
__shared__ DeviceHandle<MemoryChannel> channels[MAX_IPC_DOMAIN_NRANKS - 1];
|
||||
const int lid = threadIdx.x % WARP_SIZE;
|
||||
// Each warp redundantly loads all entries (same value, benign race) so that
|
||||
// every warp has the data its threads will read after __syncwarp(). Required
|
||||
// when nPeer > WARP_SIZE (MNNVL/NVL72 scale).
|
||||
// Peer count may exceed WARP_SIZE on MNNVL.
|
||||
for (int i = lid; i < nPeer; i += WARP_SIZE) {
|
||||
channels[i] = memoryChans[i];
|
||||
}
|
||||
|
||||
@@ -52,9 +52,7 @@ __global__ void __launch_bounds__(512, 1)
|
||||
__shared__ DeviceHandle<MemoryChannel> channels[MAX_IPC_DOMAIN_NRANKS - 1];
|
||||
__shared__ DeviceHandle<MemoryChannel> outChannels[MAX_IPC_DOMAIN_NRANKS - 1];
|
||||
const int lid = threadIdx.x % WARP_SIZE;
|
||||
// Each warp redundantly loads all entries (same value, benign race) so that
|
||||
// every warp has the data its threads will read after __syncwarp(). Required
|
||||
// when nPeer > WARP_SIZE (MNNVL/NVL72 scale).
|
||||
// Peer count may exceed WARP_SIZE on MNNVL.
|
||||
for (int i = lid; i < nPeer; i += WARP_SIZE) {
|
||||
channels[i] = memoryChans[i];
|
||||
outChannels[i] = memoryOutChans[i];
|
||||
|
||||
@@ -178,7 +178,7 @@ struct NvlsBlockPipelineAdapter {
|
||||
|
||||
void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr<Communicator> comm) {
|
||||
nSwitchChannels_ = 8;
|
||||
ipcDomainNranks_ = validateIpcDomainSpansWorld(comm, "AllreduceNvlsBlockPipeline");
|
||||
ipcDomainNranks_ = validateIpcDomainSpansWorld(comm);
|
||||
// Block-pipeline device-side semaphore indices grow as 6 * ipcDomainNranks (see kernel).
|
||||
if (6 * ipcDomainNranks_ > NUM_SEMAPHORES) {
|
||||
throw Error("AllreduceNvlsBlockPipeline: ipcDomainNranks " + std::to_string(ipcDomainNranks_) +
|
||||
|
||||
@@ -59,9 +59,7 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
auto memoryChans = memoryChannels + chanOffset;
|
||||
__shared__ DeviceHandle<BaseMemoryChannel> channels[(MAX_IPC_DOMAIN_NRANKS - 1) * 2];
|
||||
const int lid = threadIdx.x % WARP_SIZE;
|
||||
// Each warp redundantly loads all entries (same value, benign race) so that
|
||||
// every warp has the data its threads will read after __syncwarp(). Required
|
||||
// when nPeers*2 > WARP_SIZE (MNNVL scale).
|
||||
// Peer count may exceed WARP_SIZE on MNNVL.
|
||||
for (int i = lid; i < nPeers * 2; i += WARP_SIZE) {
|
||||
channels[i] = memoryChans[i];
|
||||
}
|
||||
@@ -144,7 +142,7 @@ struct NvlsWarpPipelineAdapter {
|
||||
|
||||
void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr<Communicator> comm) {
|
||||
nSwitchChannels_ = NUM_NVLS_CONNECTION;
|
||||
ipcDomainNranks_ = validateIpcDomainSpansWorld(comm, "AllreduceNvlsWarpPipeline");
|
||||
ipcDomainNranks_ = validateIpcDomainSpansWorld(comm);
|
||||
// The warp-pipeline kernel addresses 2 * nPeers entries per block in `memoryChannels`,
|
||||
// so per-peer base channel allocation must be at least `2 * nBlocks`. Default
|
||||
// nBlocks = 4 * ipcDomainNranks (see allreduceKernelFunc), so size accordingly.
|
||||
|
||||
@@ -45,9 +45,7 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
auto memoryChans = memoryChannels + chanOffset;
|
||||
__shared__ mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel> channels[MAX_IPC_DOMAIN_NRANKS - 1];
|
||||
const int lid = threadIdx.x % WARP_SIZE;
|
||||
// Each warp redundantly loads all entries (same value, benign race) so that
|
||||
// every warp has the data its threads will read after __syncwarp(). Required
|
||||
// when nPeers > WARP_SIZE (MNNVL/NVL72 → 71 peers).
|
||||
// Peer count may exceed WARP_SIZE on MNNVL.
|
||||
for (int i = lid; i < ipcDomainNranks - 1; i += WARP_SIZE) {
|
||||
channels[i] = memoryChans[i];
|
||||
}
|
||||
@@ -107,7 +105,7 @@ void AllreduceNvls::initialize(std::shared_ptr<mscclpp::Communicator> comm) {
|
||||
MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&deviceProp, device));
|
||||
computeCapabilityMajor_ = deviceProp.major;
|
||||
nSwitchChannels_ = 32;
|
||||
validateIpcDomainSpansWorld(comm, "AllreduceNvls");
|
||||
validateIpcDomainSpansWorld(comm);
|
||||
this->conns_ = setupConnections(comm);
|
||||
// setup semaphores
|
||||
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores =
|
||||
|
||||
@@ -80,9 +80,7 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
// Put channels into shared memory; reading channel info from global memory is unexpectedly slow.
|
||||
__shared__ mscclpp::DeviceHandle<mscclpp::MemoryChannel> channels[MAX_IPC_DOMAIN_NRANKS - 1];
|
||||
const int lid = tid % WARP_SIZE;
|
||||
// Each warp redundantly loads all entries (same value, benign race) so that
|
||||
// every warp has the data its threads will read after __syncwarp(). Required
|
||||
// when nPeers > WARP_SIZE (MNNVL/NVL72 scale).
|
||||
// Peer count may exceed WARP_SIZE on MNNVL.
|
||||
for (int i = lid; i < nPeers; i += WARP_SIZE) {
|
||||
channels[i] = memoryChannels[i];
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
#include <mscclpp/memory_channel.hpp>
|
||||
#include <mscclpp/switch_channel.hpp>
|
||||
|
||||
#include "logger.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace collective {
|
||||
std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_ptr<mscclpp::Communicator> comm, int rank,
|
||||
@@ -79,24 +81,22 @@ int getIpcDomainNranks(std::shared_ptr<mscclpp::Communicator> comm) {
|
||||
return comm->bootstrap()->getNranksPerNode();
|
||||
}
|
||||
|
||||
int validateIpcDomainSpansWorld(std::shared_ptr<mscclpp::Communicator> comm, const char* algName) {
|
||||
int validateIpcDomainSpansWorld(std::shared_ptr<mscclpp::Communicator> comm) {
|
||||
const int ipcDomainNranks = getIpcDomainNranks(comm);
|
||||
const int worldSize = comm->bootstrap()->getNranks();
|
||||
const int rank = comm->bootstrap()->getRank();
|
||||
if (ipcDomainNranks < 2 || ipcDomainNranks > MAX_IPC_DOMAIN_NRANKS) {
|
||||
throw mscclpp::Error(std::string(algName) + ": ipcDomainNranks " + std::to_string(ipcDomainNranks) +
|
||||
" is out of supported range [2, " + std::to_string(MAX_IPC_DOMAIN_NRANKS) + "]",
|
||||
mscclpp::ErrorCode::InvalidUsage);
|
||||
THROW(mscclpp::LogSubsys::ALGO, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage, "ipcDomainNranks ",
|
||||
ipcDomainNranks, " is out of supported range [2, ", MAX_IPC_DOMAIN_NRANKS, "]");
|
||||
}
|
||||
if (worldSize != ipcDomainNranks) {
|
||||
throw mscclpp::Error(std::string(algName) + " requires worldSize == ipcDomainNranks (got worldSize=" +
|
||||
std::to_string(worldSize) + ", ipcDomainNranks=" + std::to_string(ipcDomainNranks) + ")",
|
||||
mscclpp::ErrorCode::InvalidUsage);
|
||||
THROW(mscclpp::LogSubsys::ALGO, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage,
|
||||
"requires worldSize == ipcDomainNranks (got worldSize=", worldSize, ", ipcDomainNranks=", ipcDomainNranks,
|
||||
")");
|
||||
}
|
||||
if (rank < 0 || rank >= ipcDomainNranks) {
|
||||
throw mscclpp::Error(std::string(algName) + ": rank " + std::to_string(rank) + " out of [0, " +
|
||||
std::to_string(ipcDomainNranks) + ")",
|
||||
mscclpp::ErrorCode::InvalidUsage);
|
||||
THROW(mscclpp::LogSubsys::ALGO, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage, "rank ", rank, " out of [0, ",
|
||||
ipcDomainNranks, ")");
|
||||
}
|
||||
return ipcDomainNranks;
|
||||
}
|
||||
|
||||
@@ -60,13 +60,9 @@ int getIpcDomainNranks(std::shared_ptr<Communicator> comm);
|
||||
|
||||
/// Validates that the IPC domain spans the whole communicator and that the local rank fits within
|
||||
/// the supported `[2, MAX_IPC_DOMAIN_NRANKS]` range. Used by NVLS allreduce algorithms whose
|
||||
/// multicast group spans the whole communicator (see `setupNvlsConnections`) and whose kernels
|
||||
/// use the global rank to compute per-rank offsets while sizing per-rank work by
|
||||
/// `ipcDomainNranks`. These assumptions only hold when the IPC-reachable peer group is exactly
|
||||
/// the whole communicator (e.g. a fully populated MNNVL fabric like NVL72). Returns the validated
|
||||
/// `ipcDomainNranks`; throws `Error(InvalidUsage)` on violation. `algName` is used as a prefix
|
||||
/// in error messages.
|
||||
int validateIpcDomainSpansWorld(std::shared_ptr<Communicator> comm, const char* algName);
|
||||
/// multicast group spans the whole communicator. Returns the validated `ipcDomainNranks`; throws
|
||||
/// `Error(InvalidUsage)` on violation.
|
||||
int validateIpcDomainSpansWorld(std::shared_ptr<Communicator> comm);
|
||||
|
||||
std::shared_ptr<DeviceHandle<MemoryChannel>> setupMemoryChannelDeviceHandles(
|
||||
const std::vector<MemoryChannel>& memoryChannels);
|
||||
|
||||
Reference in New Issue
Block a user