Use maximum threads (1024) for best bandwidth utilization

This commit is contained in:
Qinghua Zhou
2026-02-18 03:00:29 +00:00
parent b7485762a5
commit 43980da455
3 changed files with 48 additions and 62 deletions

View File

@@ -71,7 +71,7 @@ void AlltoallvFullmesh::initialize(std::shared_ptr<Communicator> comm) {
CommResult AlltoallvFullmesh::alltoallvKernelFunc(
const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
size_t outputSize, [[maybe_unused]] DataType dtype, cudaStream_t stream,
[[maybe_unused]] int nBlocks, [[maybe_unused]] int nThreadsPerBlock,
[[maybe_unused]] int nBlocks, int nThreadsPerBlock,
const std::unordered_map<std::string, uintptr_t>& extras) {
auto algoCtx = std::static_pointer_cast<AllToAllVContext>(ctx);
@@ -94,14 +94,13 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
const size_t* d_recvCounts = reinterpret_cast<const size_t*>(it_recvCounts->second);
const size_t* d_recvDispls = reinterpret_cast<const size_t*>(it_recvDispls->second);
// Honor the caller's nThreadsPerBlock when it is in (0, 1024]; otherwise fall back to the maximum (1024) for best bandwidth utilization
const int threadsPerBlock = (nThreadsPerBlock > 0 && nThreadsPerBlock <= 1024) ? nThreadsPerBlock : 1024;
// Choose kernel based on world size
if (worldSize <= 16) {
// Use parallel warp-based kernel for small world sizes
int nThreads = (worldSize - 1) * ALLTOALLV_WARP_SIZE;
if (nThreads < 32) nThreads = 32;
if (nThreads > 1024) nThreads = 1024;
alltoallvKernel<<<1, nThreads, 0, stream>>>(
// Use high-throughput kernel with all threads
alltoallvKernel<<<1, threadsPerBlock, 0, stream>>>(
algoCtx->memoryChannelDeviceHandles.get(),
rank, worldSize,
input, output,
@@ -109,7 +108,7 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
d_recvCounts, d_recvDispls);
} else {
// Use ring-based kernel for larger world sizes
alltoallvRingKernel<<<1, 32, 0, stream>>>(
alltoallvRingKernel<<<1, threadsPerBlock, 0, stream>>>(
algoCtx->memoryChannelDeviceHandles.get(),
rank, worldSize,
input, output,

View File

@@ -17,10 +17,13 @@ namespace collective {
#endif
/**
* AllToAllV kernel implementation using parallel warp-based communication with MemoryChannel.
* High-performance AllToAllV kernel using maximum thread parallelism.
*
* Each warp handles communication with one peer. Data is copied in parallel using all threads
* in the warp, which significantly improves throughput for large messages.
* Processes each peer sequentially but uses ALL threads in the block (blockDim.x,
* up to 1024) for each data transfer to maximize copy bandwidth. This provides much
* better performance than the warp-per-peer approach for large message sizes.
*
* Launch config: <<<1, 1024>>> for maximum bandwidth within a single block.
*
* @param memoryChannels Array of MemoryChannel handles for each peer (worldSize-1 channels)
* @param rank Current rank
@@ -43,64 +46,54 @@ __global__ void __launch_bounds__(1024)
const size_t* recvCounts,
const size_t* recvDispls) {
int tid = threadIdx.x;
int nThreads = blockDim.x;
int nPeers = worldSize - 1;
// Step 1: Copy local data (rank's own portion) using all threads
// Step 1: Copy local data using ALL threads for maximum bandwidth
if (sendCounts[rank] > 0) {
mscclpp::copy((char*)recvBuff + recvDispls[rank],
(void*)((const char*)sendBuff + sendDispls[rank]),
sendCounts[rank], tid, blockDim.x);
sendCounts[rank], tid, nThreads);
}
__syncthreads();
// Step 2: Each warp handles one peer for sending (parallel copy within warp)
int warpId = tid / ALLTOALLV_WARP_SIZE;
int laneId = tid % ALLTOALLV_WARP_SIZE;
if (warpId < nPeers) {
// Determine which peer this warp handles
int peer = warpId < rank ? warpId : warpId + 1;
int chanIdx = warpId;
// Step 2: Process each peer sequentially, but use ALL threads for each transfer
// This maximizes bandwidth for each transfer compared to warp-per-peer approach
for (int peerIdx = 0; peerIdx < nPeers; peerIdx++) {
int peer = peerIdx < rank ? peerIdx : peerIdx + 1;
int chanIdx = peerIdx;
if (sendCounts[peer] > 0) {
// Use parallel put with all threads in the warp
// targetOffset: recvDispls[rank] - where peer should receive our data
// originOffset: sendDispls[peer] - where our data for this peer starts
// Use all threads for maximum copy throughput
memoryChannels[chanIdx].put(
recvDispls[rank], // dst offset in peer's buffer
sendDispls[peer], // src offset in our buffer
sendCounts[peer], // size
laneId, // thread id within warp
ALLTOALLV_WARP_SIZE // number of threads
tid, // thread id
nThreads // total threads
);
}
}
__syncthreads();
__syncthreads();
// Step 3: Signal completion to all peers
if (warpId < nPeers && laneId == 0) {
memoryChannels[warpId].signal();
}
__syncthreads();
// Step 4: Wait for all incoming data
if (warpId < nPeers && laneId == 0) {
int peer = warpId < rank ? warpId : warpId + 1;
if (recvCounts[peer] > 0) {
memoryChannels[warpId].wait();
// Only one thread signals per peer
if (tid == 0) {
memoryChannels[chanIdx].signal();
}
__syncthreads();
// Wait for incoming data from this peer
if (tid == 0 && recvCounts[peer] > 0) {
memoryChannels[chanIdx].wait();
}
__syncthreads();
}
__syncthreads();
}
/**
* Ring-based AllToAllV kernel for serialized communication with MemoryChannel.
* Ring-based AllToAllV kernel with maximum thread parallelism.
*
* Uses step-by-step ring pattern to exchange data, sending to (rank+step) and
* receiving from (rank-step) in each step. All threads participate in the copy
* for better throughput.
*
* This kernel is more robust for larger world sizes.
* Uses step-by-step ring pattern with ALL threads for maximum bandwidth.
* Better for larger world sizes to avoid congestion.
*/
__global__ void __launch_bounds__(1024)
alltoallvRingKernel(DeviceHandle<MemoryChannel>* memoryChannels,
@@ -113,12 +106,13 @@ __global__ void __launch_bounds__(1024)
const size_t* recvCounts,
const size_t* recvDispls) {
int tid = threadIdx.x;
int nThreads = blockDim.x;
// Copy local data first using all threads
// Copy local data first using ALL threads
if (sendCounts[rank] > 0) {
mscclpp::copy((char*)recvBuff + recvDispls[rank],
(void*)((const char*)sendBuff + sendDispls[rank]),
sendCounts[rank], tid, blockDim.x);
sendCounts[rank], tid, nThreads);
}
__syncthreads();
@@ -130,14 +124,14 @@ __global__ void __launch_bounds__(1024)
int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1;
int recvChanIdx = recvPeer < rank ? recvPeer : recvPeer - 1;
// Send data to sendPeer using all threads
// Send data to sendPeer using ALL threads
if (sendCounts[sendPeer] > 0) {
memoryChannels[sendChanIdx].put(
recvDispls[rank],
sendDispls[sendPeer],
sendCounts[sendPeer],
tid,
blockDim.x
nThreads
);
}
__syncthreads();

View File

@@ -57,18 +57,11 @@ void AllToAllVTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
const int rank = args.rank;
const int kernelNum = args.kernelNum;
// Reset device syncer
mscclpp::DeviceSyncer syncer = {};
CUDATHROW(cudaMemcpyToSymbol(deviceSyncerV, &syncer, sizeof(mscclpp::DeviceSyncer)));
// Use maximum threads (1024) for best bandwidth utilization
const int nThreads = 1024;
if (kernelNum == 0) {
// Use parallel warp-based kernel from library
int nThreads = (worldSize - 1) * 32; // One warp per peer
#if defined(__HIP_PLATFORM_AMD__)
nThreads = (worldSize - 1) * 64;
#endif
if (nThreads < 32) nThreads = 32;
if (nThreads > 1024) nThreads = 1024;
// Use high-throughput kernel with all threads participating in each transfer
mscclpp::collective::alltoallvKernel<<<1, nThreads, 0, stream>>>(
d_memoryChannels,
rank, worldSize,
@@ -76,8 +69,8 @@ void AllToAllVTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
d_sendCounts, d_sendDispls,
d_recvCounts, d_recvDispls);
} else if (kernelNum == 1) {
// Use ring-based kernel from library
mscclpp::collective::alltoallvRingKernel<<<1, 32, 0, stream>>>(
// kernelNum == 1: use the ring-based kernel (intended for larger world sizes)
mscclpp::collective::alltoallvRingKernel<<<1, nThreads, 0, stream>>>(
d_memoryChannels,
rank, worldSize,
localSendBuffV, localRecvBuffV,