mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
Support hybrid connections for single and multi node
This commit is contained in:
@@ -8,7 +8,10 @@
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/memory_channel.hpp>
|
||||
#include <mscclpp/memory_channel_device.hpp>
|
||||
#include <mscclpp/port_channel.hpp>
|
||||
#include <mscclpp/port_channel_device.hpp>
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
#include <mscclpp/utils.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
@@ -27,10 +30,25 @@ struct AllToAllVContext {
|
||||
int worldSize;
|
||||
int nRanksPerNode;
|
||||
|
||||
// Intra-node (CudaIpc) channels — MemoryChannel for direct NVLink copy
|
||||
std::vector<RegisteredMemory> registeredMemories;
|
||||
std::vector<MemoryChannel> memoryChannels;
|
||||
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores;
|
||||
std::shared_ptr<DeviceHandle<MemoryChannel>> memoryChannelDeviceHandles;
|
||||
|
||||
// Inter-node (IB) channels — PortChannel via ProxyService
|
||||
std::shared_ptr<ProxyService> proxyService;
|
||||
std::vector<PortChannel> portChannels;
|
||||
std::shared_ptr<PortChannelDeviceHandle> portChannelDeviceHandles;
|
||||
|
||||
// Peer locality map: peerIsLocal[peerIdx] = 1 if intra-node, 0 if inter-node
|
||||
// peerIdx is the index into the channel arrays (0..nPeers-1), NOT the rank
|
||||
std::shared_ptr<int> d_peerIsLocal; // GPU array [nPeers]
|
||||
// For inter-node peers, maps peerIdx → portChannel index (dense indexing)
|
||||
std::shared_ptr<int> d_peerToPortChannelIdx; // GPU array [nPeers]
|
||||
|
||||
bool hasRemotePeers; // true if any inter-node connections exist
|
||||
|
||||
std::shared_ptr<DeviceSyncer> deviceSyncer; // GPU-allocated, for multi-block grid sync
|
||||
};
|
||||
|
||||
@@ -68,12 +86,34 @@ std::shared_ptr<Algorithm> AlltoallvFullmesh::build() {
|
||||
|
||||
void AlltoallvFullmesh::initialize(std::shared_ptr<Communicator> comm) {
|
||||
worldSize_ = comm->bootstrap()->getNranks();
|
||||
this->conns_ = setupConnections(comm);
|
||||
int rank = comm->bootstrap()->getRank();
|
||||
int nRanksPerNode = comm->bootstrap()->getNranksPerNode();
|
||||
int localGpuIdx = rank % nRanksPerNode;
|
||||
|
||||
// Use hybrid connections: CudaIpc for intra-node, IB for inter-node
|
||||
bool hasIB = getIBDeviceCount() > 0;
|
||||
bool isMultiNode = (worldSize_ > nRanksPerNode);
|
||||
|
||||
if (hasIB && isMultiNode) {
|
||||
this->conns_ = setupHybridConnections(comm, localGpuIdx);
|
||||
// Check if any connections are actually inter-node
|
||||
hasRemotePeers_ = false;
|
||||
for (const auto& conn : this->conns_) {
|
||||
if (!isIntraNodeConnection(conn)) {
|
||||
hasRemotePeers_ = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Single-node or no IB: use CudaIpc for all
|
||||
this->conns_ = setupConnections(comm);
|
||||
hasRemotePeers_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
CommResult AlltoallvFullmesh::alltoallvKernelFunc(
|
||||
const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
size_t outputSize, [[maybe_unused]] DataType dtype, cudaStream_t stream,
|
||||
[[maybe_unused]] size_t outputSize, [[maybe_unused]] DataType dtype, cudaStream_t stream,
|
||||
[[maybe_unused]] int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>& extras) {
|
||||
|
||||
@@ -103,21 +143,20 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
|
||||
// Use maximum threads (1024) for best bandwidth utilization
|
||||
const int threadsPerBlock = (nThreadsPerBlock > 0 && nThreadsPerBlock <= 1024) ? nThreadsPerBlock : 1024;
|
||||
|
||||
// Peer-parallel algorithm: blocks assigned round-robin to peers so ALL
|
||||
// NVLink connections are active simultaneously. Critical for 4+ GPU systems.
|
||||
//
|
||||
// Small messages (<1MB avg): nPeers blocks (1 per peer, no barrier)
|
||||
// Large messages (>=1MB avg): nPeers * blocksPerPeer (barrier-based)
|
||||
constexpr size_t SIZE_THRESHOLD = 1 << 20; // 1MB
|
||||
size_t avgMsgSize = inputSize / worldSize;
|
||||
int nPeers = worldSize - 1;
|
||||
if (nPeers < 1) nPeers = 1;
|
||||
|
||||
if (avgMsgSize < SIZE_THRESHOLD) {
|
||||
// Small messages: 1 block per peer, parallel signal/wait, no barrier
|
||||
if (algoCtx->hasRemotePeers) {
|
||||
// Multi-node: use hybrid kernel with MemoryChannel (intra) + PortChannel (inter)
|
||||
// PortChannel put() is single-threaded (FIFO push), so we use 1 block per peer.
|
||||
// For large intra-node messages, multiple blocks per local peer would help,
|
||||
// but keeping it simple for now: 1 block per peer for both local and remote.
|
||||
int numBlocks = nPeers;
|
||||
alltoallvPeerParallelKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
|
||||
alltoallvHybridKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
|
||||
algoCtx->memoryChannelDeviceHandles.get(),
|
||||
algoCtx->portChannelDeviceHandles.get(),
|
||||
algoCtx->d_peerIsLocal.get(),
|
||||
algoCtx->d_peerToPortChannelIdx.get(),
|
||||
algoCtx->deviceSyncer.get(),
|
||||
rank, worldSize,
|
||||
input, output,
|
||||
@@ -125,22 +164,38 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
|
||||
d_recvCounts, d_recvDispls,
|
||||
d_remoteRecvDispls);
|
||||
} else {
|
||||
// Large messages: multiple blocks per peer for maximum put bandwidth.
|
||||
// Cap total blocks to avoid excessive barrier overhead.
|
||||
int blocksPerPeer = (nBlocks > 0 && nBlocks <= 128)
|
||||
? ((nBlocks + nPeers - 1) / nPeers) // user-specified total → per-peer
|
||||
: ALLTOALLV_DEFAULT_BLOCKS_PER_PEER;
|
||||
int numBlocks = nPeers * blocksPerPeer;
|
||||
if (numBlocks > 128) numBlocks = (128 / nPeers) * nPeers; // keep multiple of nPeers
|
||||
if (numBlocks < nPeers) numBlocks = nPeers;
|
||||
alltoallvPeerParallelKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
|
||||
algoCtx->memoryChannelDeviceHandles.get(),
|
||||
algoCtx->deviceSyncer.get(),
|
||||
rank, worldSize,
|
||||
input, output,
|
||||
d_sendCounts, d_sendDispls,
|
||||
d_recvCounts, d_recvDispls,
|
||||
d_remoteRecvDispls);
|
||||
// Single-node: use the optimized peer-parallel kernel (MemoryChannel only)
|
||||
constexpr size_t SIZE_THRESHOLD = 1 << 20; // 1MB
|
||||
size_t avgMsgSize = inputSize / worldSize;
|
||||
|
||||
if (avgMsgSize < SIZE_THRESHOLD) {
|
||||
// Small messages: 1 block per peer, parallel signal/wait, no barrier
|
||||
int numBlocks = nPeers;
|
||||
alltoallvPeerParallelKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
|
||||
algoCtx->memoryChannelDeviceHandles.get(),
|
||||
algoCtx->deviceSyncer.get(),
|
||||
rank, worldSize,
|
||||
input, output,
|
||||
d_sendCounts, d_sendDispls,
|
||||
d_recvCounts, d_recvDispls,
|
||||
d_remoteRecvDispls);
|
||||
} else {
|
||||
// Large messages: multiple blocks per peer for maximum put bandwidth.
|
||||
int blocksPerPeer = (nBlocks > 0 && nBlocks <= 128)
|
||||
? ((nBlocks + nPeers - 1) / nPeers)
|
||||
: ALLTOALLV_DEFAULT_BLOCKS_PER_PEER;
|
||||
int numBlocks = nPeers * blocksPerPeer;
|
||||
if (numBlocks > 128) numBlocks = (128 / nPeers) * nPeers;
|
||||
if (numBlocks < nPeers) numBlocks = nPeers;
|
||||
alltoallvPeerParallelKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
|
||||
algoCtx->memoryChannelDeviceHandles.get(),
|
||||
algoCtx->deviceSyncer.get(),
|
||||
rank, worldSize,
|
||||
input, output,
|
||||
d_sendCounts, d_sendDispls,
|
||||
d_recvCounts, d_recvDispls,
|
||||
d_remoteRecvDispls);
|
||||
}
|
||||
}
|
||||
|
||||
if (cudaGetLastError() == cudaSuccess) {
|
||||
@@ -157,29 +212,65 @@ std::shared_ptr<void> AlltoallvFullmesh::initAlltoallvContext(
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
|
||||
ctx->hasRemotePeers = this->hasRemotePeers_;
|
||||
|
||||
int rank = ctx->rank;
|
||||
int nRanksPerNode = ctx->nRanksPerNode;
|
||||
int localGpuIdx = rank % nRanksPerNode;
|
||||
|
||||
// Determine transport flags for memory registration.
|
||||
// If we have remote peers, register with both CudaIpc and IB transports.
|
||||
TransportFlags allTransports = Transport::CudaIpc;
|
||||
if (ctx->hasRemotePeers) {
|
||||
allTransports |= getIBTransportForGpu(localGpuIdx);
|
||||
}
|
||||
|
||||
// Register memories for input and output buffers
|
||||
RegisteredMemory inputBufRegMem = comm->registerMemory((void*)input, inputSize, Transport::CudaIpc);
|
||||
RegisteredMemory outputBufRegMem = comm->registerMemory(output, outputSize, Transport::CudaIpc);
|
||||
RegisteredMemory inputBufRegMem = comm->registerMemory((void*)input, inputSize, allTransports);
|
||||
RegisteredMemory outputBufRegMem = comm->registerMemory(output, outputSize, allTransports);
|
||||
|
||||
// Exchange output buffer registration with all peers (we write to peer's output buffer)
|
||||
std::vector<RegisteredMemory> remoteOutputMemories = setupRemoteMemories(comm, ctx->rank, outputBufRegMem);
|
||||
std::vector<RegisteredMemory> remoteOutputMemories = setupRemoteMemories(comm, rank, outputBufRegMem);
|
||||
|
||||
// Setup memory semaphores for synchronization (1 channel per peer)
|
||||
// Build peer locality map and channel index mappings
|
||||
int nPeers = ctx->worldSize - 1;
|
||||
std::vector<int> peerIsLocal(nPeers, 1);
|
||||
std::vector<int> peerToPortChannelIdx(nPeers, -1);
|
||||
int portChannelCount = 0;
|
||||
|
||||
for (size_t cid = 0; cid < this->conns_.size(); ++cid) {
|
||||
if (!isIntraNodeConnection(this->conns_[cid])) {
|
||||
peerIsLocal[cid] = 0;
|
||||
peerToPortChannelIdx[cid] = portChannelCount++;
|
||||
}
|
||||
}
|
||||
|
||||
// Setup intra-node MemoryChannels (CudaIpc connections only)
|
||||
constexpr int nChannelsPerConnection = 1;
|
||||
ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection);
|
||||
|
||||
// Setup memory channels: we read from our input buffer, write to peer's output buffer
|
||||
ctx->memoryChannels = setupMemoryChannels(
|
||||
this->conns_,
|
||||
ctx->memorySemaphores,
|
||||
remoteOutputMemories, // remote output buffers (where we write)
|
||||
inputBufRegMem, // local input buffer (where we read from)
|
||||
nChannelsPerConnection);
|
||||
|
||||
// Setup device handles
|
||||
ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);
|
||||
|
||||
// Setup inter-node PortChannels (IB connections only)
|
||||
if (ctx->hasRemotePeers) {
|
||||
ctx->proxyService = std::make_shared<ProxyService>();
|
||||
ctx->portChannels = setupPortChannels(
|
||||
ctx->proxyService, *comm, this->conns_, remoteOutputMemories, inputBufRegMem);
|
||||
ctx->portChannelDeviceHandles = setupPortChannelDeviceHandles(ctx->portChannels);
|
||||
ctx->proxyService->startProxy(true);
|
||||
}
|
||||
|
||||
// Copy peer locality info to GPU
|
||||
ctx->d_peerIsLocal = mscclpp::detail::gpuCallocShared<int>(nPeers);
|
||||
mscclpp::gpuMemcpy<int>(ctx->d_peerIsLocal.get(), peerIsLocal.data(), nPeers, cudaMemcpyHostToDevice);
|
||||
ctx->d_peerToPortChannelIdx = mscclpp::detail::gpuCallocShared<int>(nPeers);
|
||||
mscclpp::gpuMemcpy<int>(ctx->d_peerToPortChannelIdx.get(), peerToPortChannelIdx.data(), nPeers, cudaMemcpyHostToDevice);
|
||||
|
||||
// Allocate GPU DeviceSyncer for multi-block grid-wide barrier (zero-initialized)
|
||||
ctx->deviceSyncer = mscclpp::detail::gpuCallocShared<DeviceSyncer>();
|
||||
|
||||
|
||||
@@ -7,7 +7,9 @@
|
||||
#include <mscclpp/algorithm.hpp>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/memory_channel.hpp>
|
||||
#include <mscclpp/port_channel.hpp>
|
||||
#include <mscclpp/switch_channel.hpp>
|
||||
#include <mscclpp/utils.hpp>
|
||||
|
||||
namespace mscclpp {
|
||||
namespace collective {
|
||||
@@ -54,6 +56,80 @@ std::vector<mscclpp::Connection> setupConnections(std::shared_ptr<mscclpp::Commu
|
||||
return connections;
|
||||
}
|
||||
|
||||
// IB device array — GPU index maps to its dedicated IB device
|
||||
static const mscclpp::Transport IBs[] = {
|
||||
mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2, mscclpp::Transport::IB3,
|
||||
mscclpp::Transport::IB4, mscclpp::Transport::IB5, mscclpp::Transport::IB6, mscclpp::Transport::IB7,
|
||||
};
|
||||
|
||||
mscclpp::Transport getIBTransportForGpu(int localGpuIdx) {
|
||||
int ibCount = mscclpp::getIBDeviceCount();
|
||||
if (ibCount <= 0) {
|
||||
throw std::runtime_error("No IB devices available for inter-node communication");
|
||||
}
|
||||
int idx = localGpuIdx % ibCount;
|
||||
return IBs[idx];
|
||||
}
|
||||
|
||||
std::vector<mscclpp::Connection> setupHybridConnections(std::shared_ptr<mscclpp::Communicator> comm,
|
||||
int localGpuIdx) {
|
||||
int rank = comm->bootstrap()->getRank();
|
||||
int worldSize = comm->bootstrap()->getNranks();
|
||||
int nRanksPerNode = comm->bootstrap()->getNranksPerNode();
|
||||
int thisNode = rank / nRanksPerNode;
|
||||
|
||||
bool hasIB = mscclpp::getIBDeviceCount() > 0;
|
||||
mscclpp::Transport ibTransport = hasIB ? getIBTransportForGpu(localGpuIdx) : mscclpp::Transport::CudaIpc;
|
||||
|
||||
std::vector<std::shared_future<mscclpp::Connection>> connectionFutures;
|
||||
for (int r = 0; r < worldSize; r++) {
|
||||
if (r == rank) continue;
|
||||
mscclpp::Transport transport;
|
||||
if (r / nRanksPerNode == thisNode) {
|
||||
transport = mscclpp::Transport::CudaIpc;
|
||||
} else {
|
||||
transport = ibTransport;
|
||||
}
|
||||
connectionFutures.push_back(comm->connect(transport, r));
|
||||
}
|
||||
|
||||
std::vector<mscclpp::Connection> connections;
|
||||
std::transform(connectionFutures.begin(), connectionFutures.end(), std::back_inserter(connections),
|
||||
[](const auto& future) { return future.get(); });
|
||||
return connections;
|
||||
}
|
||||
|
||||
std::vector<mscclpp::PortChannel> setupPortChannels(
|
||||
std::shared_ptr<mscclpp::ProxyService> proxyService,
|
||||
mscclpp::Communicator& comm,
|
||||
const std::vector<mscclpp::Connection>& connections,
|
||||
const std::vector<mscclpp::RegisteredMemory>& remoteMemories,
|
||||
mscclpp::RegisteredMemory localMemory) {
|
||||
std::vector<mscclpp::PortChannel> channels;
|
||||
mscclpp::MemoryId srcMemId = proxyService->addMemory(localMemory);
|
||||
for (size_t cid = 0; cid < connections.size(); ++cid) {
|
||||
if (connections[cid].transport() != mscclpp::Transport::CudaIpc) {
|
||||
// IB connection → PortChannel
|
||||
mscclpp::SemaphoreId semId = proxyService->buildAndAddSemaphore(comm, connections[cid]);
|
||||
mscclpp::MemoryId dstMemId = proxyService->addMemory(remoteMemories[cid]);
|
||||
channels.emplace_back(proxyService->portChannel(semId, dstMemId, srcMemId));
|
||||
}
|
||||
}
|
||||
return channels;
|
||||
}
|
||||
|
||||
std::shared_ptr<mscclpp::PortChannelDeviceHandle> setupPortChannelDeviceHandles(
|
||||
const std::vector<mscclpp::PortChannel>& portChannels) {
|
||||
if (portChannels.empty()) return nullptr;
|
||||
std::vector<mscclpp::PortChannelDeviceHandle> handles;
|
||||
std::transform(portChannels.begin(), portChannels.end(), std::back_inserter(handles),
|
||||
[](const mscclpp::PortChannel& ch) { return ch.deviceHandle(); });
|
||||
auto ptr = mscclpp::detail::gpuCallocShared<mscclpp::PortChannelDeviceHandle>(handles.size());
|
||||
mscclpp::gpuMemcpy<mscclpp::PortChannelDeviceHandle>(
|
||||
ptr.get(), handles.data(), handles.size(), cudaMemcpyHostToDevice);
|
||||
return ptr;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> setupMemorySemaphores(
|
||||
std::shared_ptr<mscclpp::Communicator> comm, const std::vector<mscclpp::Connection>& connections,
|
||||
int nChannelsPerConnection) {
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <mscclpp/algorithm.hpp>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/memory_channel.hpp>
|
||||
#include <mscclpp/port_channel.hpp>
|
||||
#include <mscclpp/semaphore.hpp>
|
||||
|
||||
namespace mscclpp {
|
||||
@@ -50,6 +51,7 @@ class AlltoallvFullmesh : public AlgorithmBuilder {
|
||||
|
||||
std::vector<Connection> conns_;
|
||||
int worldSize_;
|
||||
bool hasRemotePeers_; // true if any inter-node connections
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <mscclpp/memory_channel_device.hpp>
|
||||
#include <mscclpp/port_channel_device.hpp>
|
||||
#include <mscclpp/concurrency_device.hpp>
|
||||
#include <mscclpp/copy_device.hpp>
|
||||
|
||||
@@ -29,6 +30,117 @@ constexpr int ALLTOALLV_DEFAULT_NBLOCKS = 24;
|
||||
// Controls how many thread blocks cooperate on each peer's data transfer.
|
||||
constexpr int ALLTOALLV_DEFAULT_BLOCKS_PER_PEER = 16;
|
||||
|
||||
/**
|
||||
* Hybrid AllToAllV kernel for multi-node: MemoryChannel (intra-node) + PortChannel (inter-node).
|
||||
*
|
||||
* Each block handles one peer (1 block per peer). For intra-node peers, all threads
|
||||
* cooperate on a MemoryChannel put (multi-threaded NVLink copy). For inter-node peers,
|
||||
* thread 0 pushes a PortChannel put descriptor to the CPU proxy FIFO (single-threaded),
|
||||
* which triggers an RDMA transfer.
|
||||
*
|
||||
* Key design points:
|
||||
* - MemoryChannel uses peerIdx-based dense indexing (only intra-node peers have MemoryChannels)
|
||||
* but we need the SAME peerIdx ordering as the connection array.
|
||||
* In practice, memoryChannels[] are created only for CudaIpc connections and are dense.
|
||||
* We use a separate peerToMemChIdx mapping from peerIsLocal.
|
||||
* - PortChannel uses separate dense indexing via peerToPortChannelIdx.
|
||||
* - Signal/wait is done per-peer by thread 0 of each block.
|
||||
*
|
||||
* Launch config: <<<nPeers, 1024>>>
|
||||
*/
|
||||
__global__ void __launch_bounds__(1024)
|
||||
alltoallvHybridKernel(DeviceHandle<MemoryChannel>* memoryChannels,
|
||||
PortChannelDeviceHandle* portChannels,
|
||||
const int* peerIsLocal,
|
||||
const int* peerToPortChannelIdx,
|
||||
DeviceSyncer* syncer,
|
||||
int rank,
|
||||
int worldSize,
|
||||
const void* sendBuff,
|
||||
void* recvBuff,
|
||||
const size_t* sendCounts,
|
||||
const size_t* sendDispls,
|
||||
const size_t* recvCounts,
|
||||
const size_t* recvDispls,
|
||||
const size_t* remoteRecvDispls) {
|
||||
const int nPeers = worldSize - 1;
|
||||
|
||||
// Handle trivial case (single rank)
|
||||
if (nPeers == 0) {
|
||||
const int gtid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const int nThreads = blockDim.x * gridDim.x;
|
||||
if (sendCounts[rank] > 0) {
|
||||
mscclpp::copy((char*)recvBuff + recvDispls[rank],
|
||||
(void*)((const char*)sendBuff + sendDispls[rank]),
|
||||
sendCounts[rank], gtid, nThreads);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Phase 1: Local copy — all blocks cooperate using global thread IDs
|
||||
const int gtid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const int nThreads = blockDim.x * gridDim.x;
|
||||
if (sendCounts[rank] > 0) {
|
||||
mscclpp::copy((char*)recvBuff + recvDispls[rank],
|
||||
(void*)((const char*)sendBuff + sendDispls[rank]),
|
||||
sendCounts[rank], gtid, nThreads);
|
||||
}
|
||||
|
||||
// Phase 2: Per-peer data transfer.
|
||||
// Each block handles one peer: blockIdx.x == peerIdx
|
||||
const int peerIdx = blockIdx.x;
|
||||
if (peerIdx >= nPeers) return;
|
||||
|
||||
const int peer = peerIdx < rank ? peerIdx : peerIdx + 1;
|
||||
|
||||
if (peerIsLocal[peerIdx]) {
|
||||
// Intra-node: MemoryChannel — all threads cooperate on multi-threaded put
|
||||
// MemoryChannels are densely indexed for CudaIpc connections only.
|
||||
// We need to compute the MemoryChannel index from peerIdx.
|
||||
// Count how many local peers are before this peerIdx.
|
||||
int memChIdx = 0;
|
||||
for (int i = 0; i < peerIdx; i++) {
|
||||
if (peerIsLocal[i]) memChIdx++;
|
||||
}
|
||||
|
||||
if (sendCounts[peer] > 0) {
|
||||
memoryChannels[memChIdx].put(
|
||||
remoteRecvDispls[peer], // dst offset in peer's buffer
|
||||
sendDispls[peer], // src offset in our buffer
|
||||
sendCounts[peer], // size
|
||||
threadIdx.x, // thread id within block
|
||||
blockDim.x // total threads for this peer
|
||||
);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Signal and wait (thread 0 only)
|
||||
if (threadIdx.x == 0) {
|
||||
memoryChannels[memChIdx].signal();
|
||||
if (recvCounts[peer] > 0) {
|
||||
memoryChannels[memChIdx].wait();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Inter-node: PortChannel — single-threaded FIFO push
|
||||
int portChIdx = peerToPortChannelIdx[peerIdx];
|
||||
|
||||
if (threadIdx.x == 0 && sendCounts[peer] > 0) {
|
||||
portChannels[portChIdx].putWithSignalAndFlush(
|
||||
remoteRecvDispls[peer], // dst offset
|
||||
sendDispls[peer], // src offset
|
||||
sendCounts[peer] // size
|
||||
);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Wait for incoming data from remote peer
|
||||
if (threadIdx.x == 0 && recvCounts[peer] > 0) {
|
||||
portChannels[portChIdx].wait();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Peer-parallel AllToAllV kernel for maximum throughput with multiple GPUs.
|
||||
*
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include <mscclpp/port_channel.hpp>
|
||||
#include <mscclpp/semaphore.hpp>
|
||||
#include <mscclpp/switch_channel.hpp>
|
||||
#include <mscclpp/utils.hpp>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
@@ -42,6 +43,45 @@ std::vector<MemoryChannel> setupMemoryChannels(
|
||||
const std::vector<RegisteredMemory>& remoteMemories, RegisteredMemory localMemory, int nChannelsPerConnection);
|
||||
|
||||
std::vector<Connection> setupConnections(std::shared_ptr<Communicator> comm);
|
||||
|
||||
/// Setup connections with hybrid transport: CudaIpc for intra-node, IB for inter-node.
|
||||
/// Dynamically detects if all peers are intra-node (single-node case) and falls back to CudaIpc-only.
|
||||
/// @param comm Communicator
|
||||
/// @param localGpuIdx Local GPU index within the node (used to select IB device)
|
||||
/// @return Vector of connections (one per peer)
|
||||
std::vector<Connection> setupHybridConnections(std::shared_ptr<Communicator> comm, int localGpuIdx);
|
||||
|
||||
/// Check if a connection is intra-node (CudaIpc transport).
|
||||
/// @param conn The connection to check
|
||||
/// @return true if the connection uses CudaIpc transport
|
||||
inline bool isIntraNodeConnection(const Connection& conn) {
|
||||
return conn.transport() == Transport::CudaIpc;
|
||||
}
|
||||
|
||||
/// Get the IB transport for a given local GPU index.
|
||||
/// @param localGpuIdx Local GPU index (0-7)
|
||||
/// @return The corresponding IB transport
|
||||
Transport getIBTransportForGpu(int localGpuIdx);
|
||||
|
||||
/// Setup PortChannels for inter-node connections via ProxyService.
|
||||
/// Creates PortChannels only for IB connections, with MemoryId-based addressing.
|
||||
/// @param proxyService The ProxyService managing IB transfers
|
||||
/// @param comm The communicator
|
||||
/// @param connections All connections (mixed CudaIpc + IB)
|
||||
/// @param remoteMemories Remote registered memories (one per peer)
|
||||
/// @param localMemory Local registered memory
|
||||
/// @return Vector of PortChannels (only for IB peers, in connection order)
|
||||
std::vector<PortChannel> setupPortChannels(
|
||||
std::shared_ptr<ProxyService> proxyService,
|
||||
Communicator& comm,
|
||||
const std::vector<Connection>& connections,
|
||||
const std::vector<RegisteredMemory>& remoteMemories,
|
||||
RegisteredMemory localMemory);
|
||||
|
||||
/// Setup PortChannel device handles (GPU-allocated array).
|
||||
std::shared_ptr<PortChannelDeviceHandle> setupPortChannelDeviceHandles(
|
||||
const std::vector<PortChannel>& portChannels);
|
||||
|
||||
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> setupMemorySemaphores(
|
||||
std::shared_ptr<Communicator> comm, const std::vector<Connection>& connections, int nChannelsPerConnection);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user