mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
Optimize alltoallv: hybrid kernel for multi-node IB mode, reduce Python hot-path overhead
This commit is contained in:
@@ -21,9 +21,16 @@ from mscclpp._mscclpp import (
|
||||
TcpBootstrap,
|
||||
DataType,
|
||||
ReduceOp,
|
||||
CommResult,
|
||||
)
|
||||
from mscclpp.ext.algorithm_collection_builder import AlgorithmCollectionBuilder
|
||||
|
||||
import ctypes as _ctypes
|
||||
try:
|
||||
_cudart = _ctypes.CDLL("libcudart.so")
|
||||
except Exception:
|
||||
_cudart = None
|
||||
|
||||
_DEBUG = os.environ.get("MSCCLPP_DEBUG_ALLTOALLV", "0") == "1"
|
||||
|
||||
__all__ = ["MscclppAlltoAllV", "all_to_all_single"]
|
||||
@@ -283,32 +290,19 @@ class MscclppAlltoAllV:
|
||||
stream = torch.cuda.current_stream()
|
||||
cuda_stream = stream.cuda_stream
|
||||
|
||||
# Use the full underlying storage size (not just the view's active data)
|
||||
# for the context key, so that reusing views of the same tensor with
|
||||
# different split sizes doesn't create new contexts (which leak
|
||||
# RegisteredMemory for stale buffers).
|
||||
try:
|
||||
input_alloc_size = input.untyped_storage().size()
|
||||
output_alloc_size = output.untyped_storage().size()
|
||||
except Exception:
|
||||
input_alloc_size = input.nelement() * input.element_size()
|
||||
output_alloc_size = output.nelement() * output.element_size()
|
||||
|
||||
# Execute the optimized kernel
|
||||
# Clear any stale CUDA errors before executing (the C++ code checks
|
||||
# cudaGetLastError() after the kernel and returns INTERNAL_ERROR if any
|
||||
# previous error was pending).
|
||||
torch.cuda.synchronize()
|
||||
# Also clear the CUDA error state via cudaGetLastError (consumes the error)
|
||||
import ctypes
|
||||
try:
|
||||
_cudart = ctypes.CDLL("libcudart.so")
|
||||
_last_err = _cudart.cudaGetLastError()
|
||||
if _last_err != 0 and _DEBUG:
|
||||
print(f" [rank {self._rank}] WARNING: cleared stale CUDA error code {_last_err} before execute", flush=True)
|
||||
except Exception:
|
||||
pass
|
||||
# Use the full underlying storage size for context key stability.
|
||||
# When the test reuses the same large tensor with different split sizes,
|
||||
# storage size stays constant → same context key → reuses channels.
|
||||
input_alloc_size = input.untyped_storage().size()
|
||||
output_alloc_size = output.untyped_storage().size()
|
||||
|
||||
if _DEBUG:
|
||||
# Clear stale CUDA errors (the C++ code checks cudaGetLastError
|
||||
# after the kernel and returns INTERNAL_ERROR if any was pending).
|
||||
if _cudart is not None:
|
||||
_last_err = _cudart.cudaGetLastError()
|
||||
if _last_err != 0:
|
||||
print(f" [rank {self._rank}] WARNING: cleared stale CUDA error code {_last_err} before execute", flush=True)
|
||||
print(f" [rank {self._rank}] alltoallv: calling algo.execute(input_alloc={input_alloc_size}, output_alloc={output_alloc_size})", flush=True)
|
||||
result = self._algo.execute(
|
||||
self._comm,
|
||||
@@ -327,7 +321,6 @@ class MscclppAlltoAllV:
|
||||
if _DEBUG:
|
||||
print(f" [rank {self._rank}] alltoallv: algo.execute returned {result}", flush=True)
|
||||
|
||||
from mscclpp._mscclpp import CommResult
|
||||
if result != CommResult.COMM_SUCCESS:
|
||||
# Get detailed CUDA error before raising
|
||||
try:
|
||||
|
||||
@@ -193,12 +193,16 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
|
||||
}
|
||||
|
||||
if (algoCtx->mode == MultiNodeMode::IB) {
|
||||
// ── IB mode: PortChannel kernel for ALL peers ──────────────────────
|
||||
// PortChannel handles both CudaIpc (intra) and IB (inter) connections
|
||||
// via the ProxyService proxy thread.
|
||||
// ── IB mode: Hybrid kernel ─────────────────────────────────────────
|
||||
// MemoryChannel (direct NVLink) for intra-node peers,
|
||||
// PortChannel (CPU proxy → RDMA) for inter-node peers.
|
||||
int numBlocks = nPeers;
|
||||
alltoallvPortChannelKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
|
||||
alltoallvHybridKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
|
||||
algoCtx->memoryChannelDeviceHandles.get(),
|
||||
algoCtx->portChannelDeviceHandles.get(),
|
||||
algoCtx->d_peerIsLocal.get(),
|
||||
algoCtx->d_peerToPortChannelIdx.get(),
|
||||
algoCtx->deviceSyncer.get(),
|
||||
rank, worldSize,
|
||||
sendBuff, recvBuff,
|
||||
d_sendCounts, d_sendDispls,
|
||||
@@ -308,25 +312,54 @@ std::shared_ptr<void> AlltoallvFullmesh::initAlltoallvContext(
|
||||
ctx->registeredMemories.push_back(outputBufRegMem);
|
||||
|
||||
} else if (ctx->mode == MultiNodeMode::IB) {
|
||||
// ── IB: PortChannel for ALL peers (CudaIpc intra + IB inter connections)
|
||||
// ── IB hybrid: MemoryChannel (intra-node) + PortChannel (inter-node) ──
|
||||
TransportFlags allTransports = Transport::CudaIpc | getIBTransportForGpu(localGpuIdx);
|
||||
RegisteredMemory inputBufRegMem = comm->registerMemory((void*)input, inputSize, allTransports);
|
||||
RegisteredMemory outputBufRegMem = comm->registerMemory(output, outputSize, allTransports);
|
||||
|
||||
std::vector<RegisteredMemory> remoteOutputMemories = setupRemoteMemories(comm, rank, outputBufRegMem);
|
||||
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB: input=%p (%zu B), output=%p (%zu B), remotes=%zu",
|
||||
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB hybrid: input=%p (%zu B), output=%p (%zu B), remotes=%zu",
|
||||
rank, input, inputSize, output, outputSize, remoteOutputMemories.size());
|
||||
for (size_t i = 0; i < remoteOutputMemories.size(); ++i) {
|
||||
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB: remoteOutput[%zu] data=%p, size=%zu",
|
||||
rank, i, remoteOutputMemories[i].data(), remoteOutputMemories[i].size());
|
||||
}
|
||||
|
||||
// Build peer locality map and per-type channel arrays
|
||||
int nPeers = ctx->worldSize - 1;
|
||||
int thisNode = rank / ctx->nRanksPerNode;
|
||||
std::vector<int> peerIsLocal(nPeers, 0);
|
||||
std::vector<int> peerToPortChIdx(nPeers, -1);
|
||||
int portChCount = 0;
|
||||
for (int peerIdx = 0; peerIdx < nPeers; peerIdx++) {
|
||||
int peer = peerIdx < rank ? peerIdx : peerIdx + 1;
|
||||
if (peer / ctx->nRanksPerNode == thisNode) {
|
||||
peerIsLocal[peerIdx] = 1;
|
||||
} else {
|
||||
peerToPortChIdx[peerIdx] = portChCount++;
|
||||
}
|
||||
}
|
||||
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB hybrid: nPeers=%d, localPeers=%d, remotePeers=%d",
|
||||
rank, nPeers, nPeers - portChCount, portChCount);
|
||||
|
||||
// Copy locality arrays to GPU
|
||||
ctx->d_peerIsLocal = mscclpp::detail::gpuCallocShared<int>(nPeers);
|
||||
ctx->d_peerToPortChannelIdx = mscclpp::detail::gpuCallocShared<int>(nPeers);
|
||||
mscclpp::gpuMemcpy<int>(ctx->d_peerIsLocal.get(), peerIsLocal.data(), nPeers, cudaMemcpyHostToDevice);
|
||||
mscclpp::gpuMemcpy<int>(ctx->d_peerToPortChannelIdx.get(), peerToPortChIdx.data(), nPeers, cudaMemcpyHostToDevice);
|
||||
|
||||
// MemoryChannel for intra-node CudaIpc connections (direct NVLink put)
|
||||
constexpr int nChannelsPerConnection = 1;
|
||||
ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection);
|
||||
ctx->memoryChannels = setupMemoryChannels(
|
||||
this->conns_, ctx->memorySemaphores, remoteOutputMemories, inputBufRegMem, nChannelsPerConnection);
|
||||
ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);
|
||||
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB hybrid: %zu memoryChannels (intra-node)",
|
||||
rank, ctx->memoryChannels.size());
|
||||
|
||||
// PortChannel for inter-node IB connections only (CPU proxy → RDMA)
|
||||
ctx->proxyService = std::make_shared<ProxyService>();
|
||||
ctx->portChannels = setupAllPortChannels(
|
||||
ctx->portChannels = setupPortChannels(
|
||||
ctx->proxyService, *comm, this->conns_, remoteOutputMemories, inputBufRegMem);
|
||||
ctx->portChannelDeviceHandles = setupPortChannelDeviceHandles(ctx->portChannels);
|
||||
ctx->proxyService->startProxy(true);
|
||||
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB: %zu portChannels created, proxy started",
|
||||
INFO(MSCCLPP_COLL, "[alltoallv][rank %d] IB hybrid: %zu portChannels (inter-node), proxy started",
|
||||
rank, ctx->portChannels.size());
|
||||
|
||||
ctx->registeredMemories = std::move(remoteOutputMemories);
|
||||
|
||||
Reference in New Issue
Block a user