mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 01:10:22 +00:00
Use maximum threads (1024) for best bandwidth utilization
This commit is contained in:
@@ -71,7 +71,7 @@ void AlltoallvFullmesh::initialize(std::shared_ptr<Communicator> comm) {
|
||||
CommResult AlltoallvFullmesh::alltoallvKernelFunc(
|
||||
const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
size_t outputSize, [[maybe_unused]] DataType dtype, cudaStream_t stream,
|
||||
[[maybe_unused]] int nBlocks, [[maybe_unused]] int nThreadsPerBlock,
|
||||
[[maybe_unused]] int nBlocks, int nThreadsPerBlock,
|
||||
const std::unordered_map<std::string, uintptr_t>& extras) {
|
||||
|
||||
auto algoCtx = std::static_pointer_cast<AllToAllVContext>(ctx);
|
||||
@@ -94,14 +94,13 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
|
||||
const size_t* d_recvCounts = reinterpret_cast<const size_t*>(it_recvCounts->second);
|
||||
const size_t* d_recvDispls = reinterpret_cast<const size_t*>(it_recvDispls->second);
|
||||
|
||||
// Honor the caller's nThreadsPerBlock when it is in (0, 1024]; otherwise fall back to the maximum (1024)
|
||||
const int threadsPerBlock = (nThreadsPerBlock > 0 && nThreadsPerBlock <= 1024) ? nThreadsPerBlock : 1024;
|
||||
|
||||
// Choose kernel based on world size
|
||||
if (worldSize <= 16) {
|
||||
// Use parallel warp-based kernel for small world sizes
|
||||
int nThreads = (worldSize - 1) * ALLTOALLV_WARP_SIZE;
|
||||
if (nThreads < 32) nThreads = 32;
|
||||
if (nThreads > 1024) nThreads = 1024;
|
||||
|
||||
alltoallvKernel<<<1, nThreads, 0, stream>>>(
|
||||
// Use high-throughput kernel with all threads
|
||||
alltoallvKernel<<<1, threadsPerBlock, 0, stream>>>(
|
||||
algoCtx->memoryChannelDeviceHandles.get(),
|
||||
rank, worldSize,
|
||||
input, output,
|
||||
@@ -109,7 +108,7 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
|
||||
d_recvCounts, d_recvDispls);
|
||||
} else {
|
||||
// Use ring-based kernel for larger world sizes
|
||||
alltoallvRingKernel<<<1, 32, 0, stream>>>(
|
||||
alltoallvRingKernel<<<1, threadsPerBlock, 0, stream>>>(
|
||||
algoCtx->memoryChannelDeviceHandles.get(),
|
||||
rank, worldSize,
|
||||
input, output,
|
||||
|
||||
@@ -17,10 +17,13 @@ namespace collective {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* AllToAllV kernel implementation using parallel warp-based communication with MemoryChannel.
|
||||
* High-performance AllToAllV kernel using maximum thread parallelism.
|
||||
*
|
||||
* Each warp handles communication with one peer. Data is copied in parallel using all threads
|
||||
* in the warp, which significantly improves throughput for large messages.
|
||||
* Processes each peer sequentially but uses ALL threads of the block (blockDim.x, up to 1024) for each
|
||||
* data transfer to maximize copy bandwidth. This provides much better performance
|
||||
* than the warp-per-peer approach for large message sizes.
|
||||
*
|
||||
* Launch config: <<<1, N>>> with N <= 1024 (single block); N = 1024 gives maximum bandwidth within the block.
|
||||
*
|
||||
* @param memoryChannels Array of MemoryChannel handles for each peer (worldSize-1 channels)
|
||||
* @param rank Current rank
|
||||
@@ -43,64 +46,54 @@ __global__ void __launch_bounds__(1024)
|
||||
const size_t* recvCounts,
|
||||
const size_t* recvDispls) {
|
||||
int tid = threadIdx.x;
|
||||
int nThreads = blockDim.x;
|
||||
int nPeers = worldSize - 1;
|
||||
|
||||
// Step 1: Copy local data (rank's own portion) using all threads
|
||||
// Step 1: Copy local data using ALL threads for maximum bandwidth
|
||||
if (sendCounts[rank] > 0) {
|
||||
mscclpp::copy((char*)recvBuff + recvDispls[rank],
|
||||
(void*)((const char*)sendBuff + sendDispls[rank]),
|
||||
sendCounts[rank], tid, blockDim.x);
|
||||
sendCounts[rank], tid, nThreads);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Step 2: Each warp handles one peer for sending (parallel copy within warp)
|
||||
int warpId = tid / ALLTOALLV_WARP_SIZE;
|
||||
int laneId = tid % ALLTOALLV_WARP_SIZE;
|
||||
|
||||
if (warpId < nPeers) {
|
||||
// Determine which peer this warp handles
|
||||
int peer = warpId < rank ? warpId : warpId + 1;
|
||||
int chanIdx = warpId;
|
||||
// Step 2: Process each peer sequentially, but use ALL threads for each transfer
|
||||
// This maximizes bandwidth for each transfer compared to warp-per-peer approach
|
||||
for (int peerIdx = 0; peerIdx < nPeers; peerIdx++) {
|
||||
int peer = peerIdx < rank ? peerIdx : peerIdx + 1;
|
||||
int chanIdx = peerIdx;
|
||||
|
||||
if (sendCounts[peer] > 0) {
|
||||
// Use parallel put with all threads in the warp
|
||||
// targetOffset: recvDispls[rank] - where peer should receive our data
|
||||
// originOffset: sendDispls[peer] - where our data for this peer starts
|
||||
// Use all threads for maximum copy throughput
|
||||
memoryChannels[chanIdx].put(
|
||||
recvDispls[rank], // dst offset in peer's buffer
|
||||
sendDispls[peer], // src offset in our buffer
|
||||
sendCounts[peer], // size
|
||||
laneId, // thread id within warp
|
||||
ALLTOALLV_WARP_SIZE // number of threads
|
||||
tid, // thread id
|
||||
nThreads // total threads
|
||||
);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
__syncthreads();
|
||||
|
||||
// Step 3: Signal completion to all peers
|
||||
if (warpId < nPeers && laneId == 0) {
|
||||
memoryChannels[warpId].signal();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Step 4: Wait for all incoming data
|
||||
if (warpId < nPeers && laneId == 0) {
|
||||
int peer = warpId < rank ? warpId : warpId + 1;
|
||||
if (recvCounts[peer] > 0) {
|
||||
memoryChannels[warpId].wait();
|
||||
// Only one thread signals per peer
|
||||
if (tid == 0) {
|
||||
memoryChannels[chanIdx].signal();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Wait for incoming data from this peer (skipped when recvCounts[peer] is 0)
|
||||
if (tid == 0 && recvCounts[peer] > 0) {
|
||||
memoryChannels[chanIdx].wait();
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
/**
|
||||
* Ring-based AllToAllV kernel for serialized communication with MemoryChannel.
|
||||
* Ring-based AllToAllV kernel with maximum thread parallelism.
|
||||
*
|
||||
* Uses step-by-step ring pattern to exchange data, sending to (rank+step) and
|
||||
* receiving from (rank-step) in each step. All threads participate in the copy
|
||||
* for better throughput.
|
||||
*
|
||||
* This kernel is more robust for larger world sizes.
|
||||
* Uses step-by-step ring pattern with ALL threads for maximum bandwidth.
|
||||
* Better for larger world sizes to avoid congestion.
|
||||
*/
|
||||
__global__ void __launch_bounds__(1024)
|
||||
alltoallvRingKernel(DeviceHandle<MemoryChannel>* memoryChannels,
|
||||
@@ -113,12 +106,13 @@ __global__ void __launch_bounds__(1024)
|
||||
const size_t* recvCounts,
|
||||
const size_t* recvDispls) {
|
||||
int tid = threadIdx.x;
|
||||
int nThreads = blockDim.x;
|
||||
|
||||
// Copy local data first using all threads
|
||||
// Copy local data first using ALL threads
|
||||
if (sendCounts[rank] > 0) {
|
||||
mscclpp::copy((char*)recvBuff + recvDispls[rank],
|
||||
(void*)((const char*)sendBuff + sendDispls[rank]),
|
||||
sendCounts[rank], tid, blockDim.x);
|
||||
sendCounts[rank], tid, nThreads);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@@ -130,14 +124,14 @@ __global__ void __launch_bounds__(1024)
|
||||
int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1;
|
||||
int recvChanIdx = recvPeer < rank ? recvPeer : recvPeer - 1;
|
||||
|
||||
// Send data to sendPeer using all threads
|
||||
// Send data to sendPeer using ALL threads
|
||||
if (sendCounts[sendPeer] > 0) {
|
||||
memoryChannels[sendChanIdx].put(
|
||||
recvDispls[rank],
|
||||
sendDispls[sendPeer],
|
||||
sendCounts[sendPeer],
|
||||
tid,
|
||||
blockDim.x
|
||||
nThreads
|
||||
);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@@ -57,18 +57,11 @@ void AllToAllVTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
|
||||
const int rank = args.rank;
|
||||
const int kernelNum = args.kernelNum;
|
||||
|
||||
// Reset device syncer
|
||||
mscclpp::DeviceSyncer syncer = {};
|
||||
CUDATHROW(cudaMemcpyToSymbol(deviceSyncerV, &syncer, sizeof(mscclpp::DeviceSyncer)));
|
||||
// Use maximum threads (1024) for best bandwidth utilization
|
||||
const int nThreads = 1024;
|
||||
|
||||
if (kernelNum == 0) {
|
||||
// Use parallel warp-based kernel from library
|
||||
int nThreads = (worldSize - 1) * 32; // One warp per peer
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
nThreads = (worldSize - 1) * 64;
|
||||
#endif
|
||||
if (nThreads < 32) nThreads = 32;
|
||||
if (nThreads > 1024) nThreads = 1024;
|
||||
// Use high-throughput kernel with all threads participating in each transfer
|
||||
mscclpp::collective::alltoallvKernel<<<1, nThreads, 0, stream>>>(
|
||||
d_memoryChannels,
|
||||
rank, worldSize,
|
||||
@@ -76,8 +69,8 @@ void AllToAllVTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
|
||||
d_sendCounts, d_sendDispls,
|
||||
d_recvCounts, d_recvDispls);
|
||||
} else if (kernelNum == 1) {
|
||||
// Use ring-based kernel from library
|
||||
mscclpp::collective::alltoallvRingKernel<<<1, 32, 0, stream>>>(
|
||||
// kernelNum == 1: use the ring-based kernel from the library (intended for larger world sizes)
|
||||
mscclpp::collective::alltoallvRingKernel<<<1, nThreads, 0, stream>>>(
|
||||
d_memoryChannels,
|
||||
rank, worldSize,
|
||||
localSendBuffV, localRecvBuffV,
|
||||
|
||||
Reference in New Issue
Block a user