mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-24 23:06:17 +00:00
Fix collective topology sizing
Rename native collective context workSize to worldSize and use nRanksPerIpcDomain for allpair peer topology. Include the staged DSL signal/wait pairing validation changes. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -79,7 +79,7 @@ __global__ void __launch_bounds__(1024)
|
||||
|
||||
struct Context {
|
||||
int rank;
|
||||
int workSize;
|
||||
int worldSize;
|
||||
int nRanksPerNode;
|
||||
|
||||
std::vector<mscclpp::RegisteredMemory> registeredMemories;
|
||||
@@ -140,7 +140,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
|
||||
size_t inputSize, cudaStream_t stream) {
|
||||
auto algoCtx = std::static_pointer_cast<Context>(ctx);
|
||||
int rank = algoCtx->rank;
|
||||
int worldSize = algoCtx->workSize;
|
||||
int worldSize = algoCtx->worldSize;
|
||||
|
||||
int nThreadsPerBlock = (worldSize - 1) * WARP_SIZE;
|
||||
allgather<<<1, nThreadsPerBlock, 0, stream>>>(algoCtx->portChannelDeviceHandles.get(), rank, inputSize);
|
||||
@@ -154,16 +154,16 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
|
||||
void* output, size_t inputSize, mscclpp::DataType dtype) {
|
||||
auto ctx = std::make_shared<Context>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
|
||||
|
||||
// register memories
|
||||
mscclpp::RegisteredMemory inputBufRegMem =
|
||||
comm->registerMemory((void*)input, inputSize, mscclpp::Transport::CudaIpc);
|
||||
mscclpp::RegisteredMemory outputBufRegMem =
|
||||
comm->registerMemory(output, inputSize * ctx->workSize, mscclpp::Transport::CudaIpc);
|
||||
comm->registerMemory(output, inputSize * ctx->worldSize, mscclpp::Transport::CudaIpc);
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
|
||||
for (int i = 0; i < ctx->workSize; i++) {
|
||||
for (int i = 0; i < ctx->worldSize; i++) {
|
||||
if (i == ctx->rank) continue;
|
||||
comm->sendMemory(outputBufRegMem, i, 0);
|
||||
remoteRegMemories.push_back(comm->recvMemory(i, 0));
|
||||
|
||||
@@ -47,7 +47,7 @@ __global__ void __launch_bounds__(1024)
|
||||
|
||||
struct Context {
|
||||
int rank;
|
||||
int workSize;
|
||||
int worldSize;
|
||||
int nRanksPerNode;
|
||||
|
||||
std::vector<mscclpp::RegisteredMemory> registeredMemories;
|
||||
@@ -108,7 +108,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
|
||||
cudaStream_t stream) {
|
||||
auto algoCtx = std::static_pointer_cast<Context>(ctx);
|
||||
int rank = algoCtx->rank;
|
||||
int worldSize = algoCtx->workSize;
|
||||
int worldSize = algoCtx->worldSize;
|
||||
|
||||
int nThreadsPerBlock = (worldSize - 1) * WARP_SIZE;
|
||||
allgather<<<1, nThreadsPerBlock, 0, stream>>>(algoCtx->portChannelDeviceHandles.get(), rank, inputBytes);
|
||||
@@ -122,16 +122,16 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
|
||||
void* output, size_t inputBytes, mscclpp::DataType dtype) {
|
||||
auto ctx = std::make_shared<Context>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
|
||||
|
||||
// register memories
|
||||
mscclpp::RegisteredMemory inputBufRegMem =
|
||||
comm->registerMemory((void*)input, inputBytes, mscclpp::Transport::CudaIpc);
|
||||
mscclpp::RegisteredMemory outputBufRegMem =
|
||||
comm->registerMemory(output, inputBytes * ctx->workSize, mscclpp::Transport::CudaIpc);
|
||||
comm->registerMemory(output, inputBytes * ctx->worldSize, mscclpp::Transport::CudaIpc);
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
|
||||
for (int i = 0; i < ctx->workSize; i++) {
|
||||
for (int i = 0; i < ctx->worldSize; i++) {
|
||||
if (i == ctx->rank) continue;
|
||||
comm->sendMemory(outputBufRegMem, i, 0);
|
||||
remoteRegMemories.push_back(comm->recvMemory(i, 0));
|
||||
|
||||
@@ -78,6 +78,7 @@ class MemoryChannel:
|
||||
tb_channel_ids = get_program().setup_channel(tb, self)
|
||||
op = SignalOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
|
||||
get_program().add_operation(self.src_rank, tb, op)
|
||||
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
|
||||
|
||||
def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = False):
|
||||
"""Wait for a signal through the memory channel.
|
||||
@@ -99,6 +100,7 @@ class MemoryChannel:
|
||||
tb_channel_ids = get_program().setup_channel(tb, self)
|
||||
op = WaitOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
|
||||
get_program().add_operation(self.src_rank, tb, op)
|
||||
get_program().register_wait(self.src_rank, self.dst_rank, self.channel_type)
|
||||
|
||||
def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
|
||||
"""Retrieve data from remote memory to local memory.
|
||||
@@ -508,6 +510,7 @@ class PortChannel:
|
||||
tb_channel_ids = get_program().setup_channel(tb, self)
|
||||
op = SignalOperation(tb_channel_ids, self.channel_type, data_sync)
|
||||
get_program().add_operation(self.src_rank, tb, op)
|
||||
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
|
||||
|
||||
def wait(self, tb: int, data_sync: SyncType = SyncType.both):
|
||||
"""Wait for a signal through the port channel.
|
||||
@@ -527,6 +530,7 @@ class PortChannel:
|
||||
tb_channel_ids = get_program().setup_channel(tb, self)
|
||||
op = WaitOperation(tb_channel_ids, self.channel_type, data_sync)
|
||||
get_program().add_operation(self.src_rank, tb, op)
|
||||
get_program().register_wait(self.src_rank, self.dst_rank, self.channel_type)
|
||||
|
||||
def flush(self, tb: int, data_sync: SyncType = SyncType.both):
|
||||
"""Flush pending operations through the port channel.
|
||||
@@ -636,6 +640,7 @@ class PortChannel:
|
||||
)
|
||||
|
||||
get_program().add_operation(self.src_rank, tb, op)
|
||||
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
|
||||
|
||||
def put_with_signal_and_flush(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
|
||||
"""Send data from local memory to remote memory with signal and flush.
|
||||
@@ -681,6 +686,7 @@ class PortChannel:
|
||||
)
|
||||
|
||||
get_program().add_operation(self.src_rank, tb, op)
|
||||
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
|
||||
|
||||
def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
|
||||
"""Transfer data from local buffer to remote scratch buffer in packet format.
|
||||
|
||||
@@ -10,6 +10,7 @@ from mscclpp.language.rank import Semaphore
|
||||
from mscclpp.language.collectives import *
|
||||
from mscclpp.language.utils import AlgoSpec, ReplicationPolicy
|
||||
from typing import List
|
||||
from collections import defaultdict
|
||||
import json
|
||||
|
||||
|
||||
@@ -112,6 +113,9 @@ class CollectiveProgram:
|
||||
|
||||
self.loop_context = None
|
||||
|
||||
self._signal_counts = defaultdict(int)
|
||||
self._wait_counts = defaultdict(int)
|
||||
|
||||
@classmethod
|
||||
def from_spec(cls, spec: AlgoSpec):
|
||||
"""Initialize a new CollectiveProgram from an algorithm specification.
|
||||
@@ -206,7 +210,35 @@ class CollectiveProgram:
|
||||
else:
|
||||
self.gpus[rank].add_operation(tb, operation)
|
||||
|
||||
def register_signal(self, src_rank, dst_rank, channel_type):
|
||||
"""Record that `src_rank` issued a signal targeting `dst_rank` over `channel_type`."""
|
||||
self._signal_counts[(src_rank, dst_rank, channel_type)] += 1
|
||||
|
||||
def register_wait(self, src_rank, dst_rank, channel_type):
|
||||
"""Record that `src_rank` performed a wait for `dst_rank` over `channel_type`."""
|
||||
self._wait_counts[(src_rank, dst_rank, channel_type)] += 1
|
||||
|
||||
def validate_signal_wait_pairing(self):
|
||||
"""Validate that every signal issued by a rank is matched by a wait on the peer rank.
|
||||
|
||||
For each (src_rank, dst_rank, channel_type) triple, the number of signals sent by
|
||||
`src_rank` to `dst_rank` must equal the number of waits performed by `dst_rank`
|
||||
for `src_rank` on a channel of the same type. Raises RuntimeError on mismatch.
|
||||
"""
|
||||
keys = set(self._signal_counts.keys()) | {(dst, src, t) for (src, dst, t) in self._wait_counts.keys()}
|
||||
for src_rank, dst_rank, channel_type in keys:
|
||||
signals = self._signal_counts.get((src_rank, dst_rank, channel_type), 0)
|
||||
waits = self._wait_counts.get((dst_rank, src_rank, channel_type), 0)
|
||||
if signals != waits:
|
||||
raise RuntimeError(
|
||||
f"Signal/Wait mismatch on {channel_type}: rank {src_rank} issues {signals} "
|
||||
f"signal(s) to rank {dst_rank}, but rank {dst_rank} performs {waits} wait(s) "
|
||||
f"for rank {src_rank}. Every signal must be matched by a corresponding wait "
|
||||
f"on the peer rank over a channel of the same type."
|
||||
)
|
||||
|
||||
def post_process_operations(self):
|
||||
self.validate_signal_wait_pairing()
|
||||
for gpu in self.gpus:
|
||||
if self.instr_fusion:
|
||||
gpu.optimize_operations()
|
||||
|
||||
@@ -127,11 +127,11 @@ CommResult AllgatherFullmesh::allgatherKernelFunc(const std::shared_ptr<void> ct
|
||||
if ((char*)input == (char*)output + rank * inputSize) {
|
||||
allgatherFullmesh<false><<<numBlocksAndThreads.first, numBlocksAndThreads.second, 0, stream>>>(
|
||||
(void*)input, this->scratchBuffer_, (void*)output, ctx->memoryChannelDeviceHandles.get(), rank,
|
||||
ctx->nRanksPerIpcDomain, ctx->workSize, nElem);
|
||||
ctx->nRanksPerIpcDomain, ctx->worldSize, nElem);
|
||||
} else {
|
||||
allgatherFullmesh<true><<<numBlocksAndThreads.first, numBlocksAndThreads.second, 0, stream>>>(
|
||||
(void*)input, this->scratchBuffer_, (void*)output, ctx->memoryChannelDeviceHandles.get(), rank,
|
||||
ctx->nRanksPerIpcDomain, ctx->workSize, nElem);
|
||||
ctx->nRanksPerIpcDomain, ctx->worldSize, nElem);
|
||||
}
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
@@ -147,7 +147,7 @@ std::shared_ptr<void> AllgatherFullmesh::initAllgatherContext(std::shared_ptr<Co
|
||||
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
|
||||
// setup semaphores
|
||||
|
||||
@@ -139,11 +139,11 @@ CommResult AllgatherFullmesh2::allgatherKernelFunc(const std::shared_ptr<void> c
|
||||
size_t channelOutOffset = *static_cast<size_t*>(ctx->extras["channel_out_offset"].get());
|
||||
if ((char*)input == (char*)output + rank * inputSize) {
|
||||
allgatherFullmesh2<false><<<numBlocksAndThreads.first, numBlocksAndThreads.second, 0, stream>>>(
|
||||
(void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->workSize,
|
||||
(void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->worldSize,
|
||||
ctx->nRanksPerIpcDomain, nElem);
|
||||
} else {
|
||||
allgatherFullmesh2<true><<<numBlocksAndThreads.first, numBlocksAndThreads.second, 0, stream>>>(
|
||||
(void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->workSize,
|
||||
(void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->worldSize,
|
||||
ctx->nRanksPerIpcDomain, nElem);
|
||||
}
|
||||
cudaError_t err = cudaGetLastError();
|
||||
@@ -158,7 +158,7 @@ std::shared_ptr<void> AllgatherFullmesh2::initAllgatherContext(std::shared_ptr<m
|
||||
void* output, size_t inputSize, mscclpp::DataType) {
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
|
||||
// setup semaphores
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include "allreduce/allreduce_allpair_packet.hpp"
|
||||
#include "allreduce/common.hpp"
|
||||
#include "collective_utils.hpp"
|
||||
#include "debug.h"
|
||||
#include "logger.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
namespace collective {
|
||||
@@ -24,22 +24,30 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand
|
||||
size_t scratchBaseOffset = (flag % numScratchBuff) ? (scratchBufferSize / numScratchBuff) : 0;
|
||||
size_t channelScratchOffset = scratchBaseOffset;
|
||||
|
||||
const int nBlocksPerPeer = gridDim.x / nPeers;
|
||||
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
|
||||
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
|
||||
const int peerIdx = blockIdx.x / nBlocksPerPeer;
|
||||
size_t srcOffset = channelDataOffset;
|
||||
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
size_t scratchOffset = channelScratchOffset + rank * nelems * sizeof(LL8Packet);
|
||||
void* scratchBuff = (void*)((char*)scratch + channelScratchOffset);
|
||||
uint32_t* src = (uint32_t*)((char*)buff);
|
||||
uint32_t* dst = (uint32_t*)((char*)resultBuff);
|
||||
|
||||
// step 1: write data to each peer's scratch buffer
|
||||
memoryChannels[peerIdx].putPackets<LL8Packet>(scratchOffset, srcOffset, nelems * sizeof(uint32_t), tid,
|
||||
blockDim.x * nBlocksPerPeer, flag);
|
||||
const int warpId = threadIdx.x / WARP_SIZE;
|
||||
const int lane = threadIdx.x % WARP_SIZE;
|
||||
const int nWarpsPerBlock = blockDim.x / WARP_SIZE;
|
||||
// Assign one warp in every block to each peer. Each peer warp sends the
|
||||
// same block-owned stripe, so nBlocks only partitions data and no longer
|
||||
// needs to be grouped by nPeers.
|
||||
if (warpId < nPeers) {
|
||||
memoryChannels[warpId].putPackets<LL8Packet>(scratchOffset, channelDataOffset, nelems * sizeof(uint32_t),
|
||||
lane + blockIdx.x * WARP_SIZE, gridDim.x * WARP_SIZE, flag);
|
||||
}
|
||||
// Safe for in-place allreduce: all peer warps must finish reading src for
|
||||
// this block's stripe before any warp writes reduced data back to dst/src.
|
||||
__syncthreads();
|
||||
|
||||
// step 2: Reduce Data
|
||||
for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nelems; idx += blockDim.x * gridDim.x) {
|
||||
// Split the same sent stream across all warps for reduction. warpId selects
|
||||
// which strided subset to reduce while lane preserves coalesced packet reads.
|
||||
for (size_t idx = lane + blockIdx.x * WARP_SIZE + warpId * WARP_SIZE * gridDim.x; idx < nelems;
|
||||
idx += nWarpsPerBlock * WARP_SIZE * gridDim.x) {
|
||||
uint32_t data = src[idx];
|
||||
using AccRaw = std::conditional_t<std::is_same_v<T, AccumT>, uint32_t,
|
||||
mscclpp::VectorType<AccumT, sizeof(uint32_t) / sizeof(T)>>;
|
||||
@@ -56,16 +64,16 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand
|
||||
if (threadIdx.x == 0) {
|
||||
((uint32_t*)flags)[blockIdx.x] = flag + 1;
|
||||
}
|
||||
if (blockIdx.x == 0 && threadIdx.x >= gridDim.x && threadIdx.x < flagSize / sizeof(uint32_t)) {
|
||||
((uint32_t*)flags)[threadIdx.x] = flag + 1;
|
||||
if (tid >= gridDim.x && tid < flagSize / sizeof(uint32_t)) {
|
||||
((uint32_t*)flags)[tid] = flag + 1;
|
||||
}
|
||||
}
|
||||
|
||||
inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int worldSize) {
|
||||
if (inputSize < worldSize * sizeof(int)) {
|
||||
return {worldSize - 1, 32};
|
||||
inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int nRanksPerIpcDomain) {
|
||||
if (inputSize < nRanksPerIpcDomain * sizeof(int)) {
|
||||
return {nRanksPerIpcDomain - 1, (nRanksPerIpcDomain - 1) * WARP_SIZE};
|
||||
}
|
||||
return {(worldSize - 1) * 4, 512};
|
||||
return {(nRanksPerIpcDomain - 1) * 4, 512};
|
||||
}
|
||||
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
@@ -77,9 +85,6 @@ struct AllpairAdapter {
|
||||
int nThreadsPerBlock = 0) {
|
||||
using ChannelType = DeviceHandle<MemoryChannel>;
|
||||
const size_t nelems = inputSize / sizeof(T);
|
||||
// Round nBlocks to multiple of nPeers so every block maps to a valid peer.
|
||||
const int nPeers = nRanksPerIpcDomain - 1;
|
||||
nBlocks = (nBlocks / nPeers) * nPeers;
|
||||
allreduceAllPairs<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
(T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank,
|
||||
nRanksPerIpcDomain, worldSize, nelems, numScratchBuff, flags, flagSize);
|
||||
@@ -101,18 +106,27 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr<voi
|
||||
const std::unordered_map<std::string, uintptr_t>&,
|
||||
DataType accumDtype) {
|
||||
auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
|
||||
if (algoCtx->workSize != algoCtx->nRanksPerIpcDomain) {
|
||||
WARN("AllreduceAllpairPacket requires workSize to match nRanksPerIpcDomain, got workSize=%d, nRanksPerIpcDomain=%d",
|
||||
algoCtx->workSize, algoCtx->nRanksPerIpcDomain);
|
||||
if (algoCtx->worldSize != algoCtx->nRanksPerIpcDomain) {
|
||||
WARN(ALGO,
|
||||
"AllreduceAllpairPacket requires worldSize to match nRanksPerIpcDomain, got worldSize=", algoCtx->worldSize,
|
||||
", nRanksPerIpcDomain=", algoCtx->nRanksPerIpcDomain);
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
std::pair<int, int> blockAndThreadNum{nBlocks, nThreadsPerBlock};
|
||||
if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
|
||||
blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, algoCtx->nRanksPerIpcDomain);
|
||||
}
|
||||
// nBlocks must be at least nPeers for allpair — each block maps to one peer.
|
||||
if (blockAndThreadNum.first > maxBlockNum_) {
|
||||
WARN(ALGO, "Requested block number ", blockAndThreadNum.first, " exceeds the maximum supported block number ",
|
||||
maxBlockNum_, ".");
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
const int nPeers = algoCtx->nRanksPerIpcDomain - 1;
|
||||
if (blockAndThreadNum.first < nPeers) {
|
||||
// The kernel maps peer sends by warpId, so every peer needs a full warp.
|
||||
if (blockAndThreadNum.second % WARP_SIZE != 0 || blockAndThreadNum.second / WARP_SIZE < nPeers) {
|
||||
WARN(ALGO,
|
||||
"Allpair packet requires at least one full warp per peer, but got nThreadsPerBlock=", blockAndThreadNum.second,
|
||||
" and nPeers=", nPeers, ".");
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
size_t sendBytes;
|
||||
@@ -122,16 +136,17 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr<voi
|
||||
|
||||
AllreduceFunc allreduce = dispatch<AllpairAdapter>(op, dtype, accumDtype);
|
||||
if (!allreduce) {
|
||||
WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", op, static_cast<int>(dtype));
|
||||
WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
|
||||
", dtype=", static_cast<int>(dtype));
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
cudaError_t error =
|
||||
allreduce(input, this->scratchBuffer_, output, algoCtx->memoryChannelDeviceHandles.get(), nullptr, nullptr,
|
||||
nullptr, channelInOffset, 0, this->scratchBufferSize_, algoCtx->rank, algoCtx->nRanksPerIpcDomain,
|
||||
algoCtx->workSize, inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_,
|
||||
algoCtx->worldSize, inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_,
|
||||
this->nSegmentsForScratchBuffer_, blockAndThreadNum.first, blockAndThreadNum.second);
|
||||
if (error != cudaSuccess) {
|
||||
WARN("AllreducePacket failed with error: %s", cudaGetErrorString(error));
|
||||
WARN(ALGO, "AllreducePacket failed with error: ", cudaGetErrorString(error));
|
||||
return CommResult::CommUnhandledCudaError;
|
||||
}
|
||||
return CommResult::CommSuccess;
|
||||
@@ -142,7 +157,7 @@ std::shared_ptr<void> AllreduceAllpairPacket::initAllreduceContext(std::shared_p
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
const int nChannelsPerConnection = maxBlockNum_;
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
ctx->memorySemaphores = this->memorySemaphores_;
|
||||
ctx->registeredMemories = this->registeredMemories_;
|
||||
|
||||
@@ -223,7 +223,7 @@ CommResult AllreduceFullmesh::allreduceKernelFunc(
|
||||
}
|
||||
cudaError_t error =
|
||||
allreduce(input, this->scratchBuffer_, output, inputChannelHandles.get(), ctx->memoryChannelDeviceHandles.get(),
|
||||
nullptr, nullptr, 0, channelOutOffset, 0, ctx->rank, ctx->nRanksPerIpcDomain, ctx->workSize, inputSize,
|
||||
nullptr, nullptr, 0, channelOutOffset, 0, ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize,
|
||||
stream, nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second);
|
||||
if (error != cudaSuccess) {
|
||||
WARN("AllreduceAllconnect failed with error: %s", cudaGetErrorString(error));
|
||||
@@ -249,7 +249,7 @@ std::shared_ptr<void> AllreduceFullmesh::initAllreduceContext(std::shared_ptr<Co
|
||||
void* output, size_t size, DataType) {
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
|
||||
// setup semaphores
|
||||
|
||||
@@ -205,7 +205,7 @@ CommResult AllreduceNvlsBlockPipeline::allreduceKernelFunc(
|
||||
}
|
||||
cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->memoryChannelsDeviceHandle_.get(), nullptr,
|
||||
ctx->switchChannelDeviceHandles.get(), nullptr, 0, 0, this->scratchBufferSize_,
|
||||
ctx->rank, ctx->nRanksPerIpcDomain, ctx->workSize, inputSize, stream, nullptr, 0, 0,
|
||||
ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, nullptr, 0, 0,
|
||||
blockAndThreadNum.first, blockAndThreadNum.second);
|
||||
if (error != cudaSuccess) {
|
||||
WARN("AllreduceNvlsBlockPipeline failed with error: %s", cudaGetErrorString(error));
|
||||
@@ -222,7 +222,7 @@ std::shared_ptr<void> AllreduceNvlsBlockPipeline::initAllreduceContext(std::shar
|
||||
void*, size_t, DataType) {
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
|
||||
// setup channels
|
||||
|
||||
@@ -93,7 +93,7 @@ std::shared_ptr<void> AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr<
|
||||
size_t, DataType) {
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
|
||||
// setup channels
|
||||
@@ -123,7 +123,7 @@ CommResult AllreduceNvlsPacket::allreduceKernelFunc(const std::shared_ptr<void>
|
||||
}
|
||||
cudaError_t error =
|
||||
allreduce(input, this->scratchBuffer_, output, nullptr, nullptr, ctx->switchChannelDeviceHandles.get(), nullptr,
|
||||
0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerIpcDomain, ctx->workSize, inputSize, stream,
|
||||
0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream,
|
||||
(void*)flagBuffer_, (uint32_t)flagBufferSize_, 0, blockAndThreadNum.first, blockAndThreadNum.second);
|
||||
if (error != cudaSuccess) {
|
||||
WARN(ALGO, "AllreduceNvlsPacket failed with error: ", cudaGetErrorString(error));
|
||||
|
||||
@@ -169,7 +169,7 @@ CommResult AllreduceNvlsWarpPipeline::allreduceKernelFunc(
|
||||
}
|
||||
cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->memoryChannelsDeviceHandle_.get(), nullptr,
|
||||
ctx->switchChannelDeviceHandles.get(), nullptr, 0, 0, this->scratchBufferSize_,
|
||||
ctx->rank, ctx->nRanksPerIpcDomain, ctx->workSize, inputSize, stream, nullptr, 0, 0,
|
||||
ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, nullptr, 0, 0,
|
||||
blockAndThreadNum.first, blockAndThreadNum.second);
|
||||
if (error != cudaSuccess) {
|
||||
WARN("AllreduceNvlsWarpPipeline failed with error: %s", cudaGetErrorString(error));
|
||||
@@ -186,7 +186,7 @@ std::shared_ptr<void> AllreduceNvlsWarpPipeline::initAllreduceContext(std::share
|
||||
void*, size_t, DataType) {
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
|
||||
// setup channels
|
||||
|
||||
@@ -149,17 +149,17 @@ CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr<void> ctx_vo
|
||||
// the number of GPUs. Empirically, 32 blocks works well for 4 GPUs and 16 for 8 GPUs, which
|
||||
// follows the formula 128 / nGPUs, clamped to [1, MAX_NBLOCKS].
|
||||
if (computeCapabilityMajor_ == 10) {
|
||||
numBlocksAndThreads.first = ::max(1, ::min(128 / ctx->workSize, MAX_NBLOCKS));
|
||||
numBlocksAndThreads.first = ::max(1, ::min(128 / ctx->worldSize, MAX_NBLOCKS));
|
||||
}
|
||||
}
|
||||
if (numBlocksAndThreads.first > MAX_NBLOCKS) {
|
||||
WARN("Number of blocks exceeds maximum supported value of %d", MAX_NBLOCKS);
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
cudaError_t error =
|
||||
allreduce(nullptr, nullptr, nullptr, this->memoryChannelsDeviceHandle_.get(), nullptr, nvlsChannels,
|
||||
nvlsOutChannels, channelInOffset, channelOutOffset, 0, ctx->rank, ctx->nRanksPerIpcDomain,
|
||||
ctx->workSize, inputSize, stream, nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second);
|
||||
cudaError_t error = allreduce(nullptr, nullptr, nullptr, this->memoryChannelsDeviceHandle_.get(), nullptr,
|
||||
nvlsChannels, nvlsOutChannels, channelInOffset, channelOutOffset, 0, ctx->rank,
|
||||
ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, nullptr, 0, 0,
|
||||
numBlocksAndThreads.first, numBlocksAndThreads.second);
|
||||
if (error != cudaSuccess) {
|
||||
if (error == cudaErrorNotSupported) {
|
||||
WARN("AllreduceNvls does not support the requested data type.");
|
||||
@@ -185,7 +185,7 @@ std::shared_ptr<void> AllreduceNvls::initAllreduceContext(std::shared_ptr<mscclp
|
||||
const void* input, void* output, size_t, mscclpp::DataType) {
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
|
||||
size_t sendBytes, recvBytes;
|
||||
|
||||
@@ -230,20 +230,25 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr<void> ctx_
|
||||
const std::unordered_map<std::string, uintptr_t>&,
|
||||
DataType accumDtype) {
|
||||
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
|
||||
if (ctx->workSize != ctx->nRanksPerIpcDomain) {
|
||||
WARN(ALGO, "AllreducePacket requires workSize to match nRanksPerIpcDomain, got workSize=", ctx->workSize,
|
||||
if (ctx->worldSize != ctx->nRanksPerIpcDomain) {
|
||||
WARN(ALGO, "AllreducePacket requires worldSize to match nRanksPerIpcDomain, got worldSize=", ctx->worldSize,
|
||||
", nRanksPerIpcDomain=", ctx->nRanksPerIpcDomain);
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
std::pair<int, int> blockAndThreadNum = {nBlocks, nThreadsPerBlock};
|
||||
if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
|
||||
blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, ctx->nRanksPerIpcDomain, ctx->workSize, dtype);
|
||||
blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, ctx->nRanksPerIpcDomain, ctx->worldSize, dtype);
|
||||
} else {
|
||||
const int nPeers = ctx->nRanksPerIpcDomain - 1;
|
||||
if (blockAndThreadNum.first < nPeers) {
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
}
|
||||
if (blockAndThreadNum.first > maxBlockNum_) {
|
||||
WARN(ALGO, "Requested block number ", blockAndThreadNum.first, " exceeds the maximum supported block number ",
|
||||
maxBlockNum_, ".");
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
|
||||
size_t sendBytes;
|
||||
CUdeviceptr sendBasePtr;
|
||||
@@ -258,7 +263,7 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr<void> ctx_
|
||||
}
|
||||
cudaError_t error =
|
||||
allreduce(input, this->scratchBuffer_, output, ctx->memoryChannelDeviceHandles.get(), nullptr, nullptr, nullptr,
|
||||
channelInOffset, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerIpcDomain, ctx->workSize,
|
||||
channelInOffset, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize,
|
||||
inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, this->nSegmentsForScratchBuffer_,
|
||||
blockAndThreadNum.first, blockAndThreadNum.second);
|
||||
if (error != cudaSuccess) {
|
||||
@@ -273,7 +278,7 @@ std::shared_ptr<void> AllreducePacket::initAllreduceContext(std::shared_ptr<Comm
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
const int nChannelsPerConnection = maxBlockNum_;
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
ctx->memorySemaphores = this->memorySemaphores_;
|
||||
ctx->registeredMemories = this->registeredMemories_;
|
||||
|
||||
@@ -185,7 +185,7 @@ CommResult AllreduceRsAg::allreduceKernelFunc(const std::shared_ptr<void> ctx, c
|
||||
}
|
||||
cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->baseMemoryChannelHandles_.get(),
|
||||
this->remoteMemoryHandles_.get(), nullptr, nullptr, 0, 0, 0, algoCtx->rank,
|
||||
algoCtx->nRanksPerIpcDomain, algoCtx->workSize, inputSize, stream, nullptr, 0, 0,
|
||||
algoCtx->nRanksPerIpcDomain, algoCtx->worldSize, inputSize, stream, nullptr, 0, 0,
|
||||
numBlocksAndThreads.first, numBlocksAndThreads.second);
|
||||
if (error != cudaSuccess) {
|
||||
WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error));
|
||||
@@ -202,7 +202,7 @@ std::shared_ptr<void> AllreduceRsAg::initAllreduceContext(std::shared_ptr<Commun
|
||||
size_t, DataType) {
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
|
||||
ctx->memorySemaphores = this->scratchSemaphores_;
|
||||
|
||||
@@ -288,7 +288,7 @@ CommResult AllreduceRsAgPipeline::allreduceKernelFunc(
|
||||
std::pair<int, int> numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
|
||||
cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->baseMemoryChannelHandles_.get(),
|
||||
this->remoteMemoryHandles_.get(), nullptr, nullptr, 0, 0, this->scratchBufferSize_,
|
||||
algoCtx->rank, algoCtx->nRanksPerIpcDomain, algoCtx->workSize, inputSize, stream,
|
||||
algoCtx->rank, algoCtx->nRanksPerIpcDomain, algoCtx->worldSize, inputSize, stream,
|
||||
nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second);
|
||||
if (error != cudaSuccess) {
|
||||
WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error));
|
||||
@@ -305,7 +305,7 @@ std::shared_ptr<void> AllreduceRsAgPipeline::initAllreduceContext(std::shared_pt
|
||||
void*, size_t, DataType) {
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
|
||||
ctx->memorySemaphores = this->scratchSemaphores_;
|
||||
|
||||
@@ -172,7 +172,7 @@ CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr<void
|
||||
}
|
||||
cudaError_t error =
|
||||
allreduce(input, nullptr, output, this->baseMemoryChannelHandles_.get(), algoCtx->remoteMemoryHandles.get(),
|
||||
nullptr, nullptr, 0, 0, 0, algoCtx->rank, algoCtx->nRanksPerIpcDomain, algoCtx->workSize, inputSize,
|
||||
nullptr, nullptr, 0, 0, 0, algoCtx->rank, algoCtx->nRanksPerIpcDomain, algoCtx->worldSize, inputSize,
|
||||
stream, nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second);
|
||||
if (error != cudaSuccess) {
|
||||
if (error == cudaErrorInvalidValue) {
|
||||
@@ -203,7 +203,7 @@ std::shared_ptr<void> AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_pt
|
||||
void* output, size_t size, DataType) {
|
||||
auto ctx = std::make_shared<AlgorithmCtx>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
|
||||
|
||||
ctx->memorySemaphores = this->semaphores_;
|
||||
|
||||
@@ -30,9 +30,7 @@ class AllreduceAllpairPacket : public AlgorithmBuilder {
|
||||
void* scratchBuffer_;
|
||||
size_t scratchBufferSize_;
|
||||
const int nSegmentsForScratchBuffer_ = 2;
|
||||
// Must be at least MAX_IPC_DOMAIN_NRANKS-1 so the adapter can launch one
|
||||
// block per peer at MNNVL scale.
|
||||
const int maxBlockNum_ = MAX_IPC_DOMAIN_NRANKS - 1;
|
||||
const int maxBlockNum_ = 64;
|
||||
std::vector<Connection> conns_;
|
||||
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores_;
|
||||
std::vector<RegisteredMemory> registeredMemories_;
|
||||
|
||||
@@ -79,7 +79,7 @@ std::shared_ptr<DeviceHandle<BaseMemoryChannel>> setupBaseMemoryChannelDeviceHan
|
||||
class AlgorithmCtx {
|
||||
public:
|
||||
int rank;
|
||||
int workSize;
|
||||
int worldSize;
|
||||
int nRanksPerIpcDomain;
|
||||
|
||||
std::vector<RegisteredMemory> registeredMemories;
|
||||
|
||||
Reference in New Issue
Block a user