Simplify/improve barrier in AllReduce6 (#317)

Drop superfluous __threadfence_system()
This commit is contained in:
Roshan Dathathri
2024-06-23 14:08:59 -07:00
committed by GitHub
parent 34f4d9d006
commit 91550dab4c

View File

@@ -788,6 +788,24 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
// Synchronizes this device with its peers, then synchronizes all blocks of
// the local grid.
//
// Must be reached by every thread of every block on every participating
// device: the trailing deviceSyncer.sync() is executed unconditionally by
// all callers.
//
// \p semaphores  array indexed by thread id; entries [0, num_ranks - 1) are
//                signaled and then waited on (presumably one per peer
//                device -- confirm against the host-side setup).
// \p thread_id   caller's thread index within its block.
// \p block_id    caller's block index within the grid.
// \p num_blocks  number of blocks passed to deviceSyncer.sync().
// \p num_ranks   number of participating ranks, including this one.
//
// Block 0 needs at least \p num_ranks - 1 threads; otherwise the
// higher-indexed semaphores are never signaled or waited on.
// NOTE(review): any memory-ordering guarantee comes from signal()/wait() and
// deviceSyncer.sync(), none of which are visible here -- verify before
// relying on this as a fence.
__forceinline__ __device__ void barrier(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores, int thread_id,
                                        int block_id, int num_blocks, int num_ranks) {
  // There is no semaphore for self, hence one fewer than num_ranks.
  const int num_peers = num_ranks - 1;

  // Only the leading threads of block 0 take part in the cross-device
  // handshake, one semaphore per thread.
  const bool handles_peer = (block_id == 0) && (thread_id < num_peers);
  if (handles_peer) {
    mscclpp::SmDevice2DeviceSemaphoreDeviceHandle& peer_semaphore = semaphores[thread_id];
    peer_semaphore.signal();
    peer_semaphore.wait();
  }

  // Hold every thread of every block on this device until the handshake
  // above has completed.
  deviceSyncer.sync(num_blocks);
}
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce6(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, TYPE* buff, int my_rank, int nranks,
@@ -796,17 +814,12 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
float* mc_ptr = (float*)nvlsPtrs.mcPtr;
int tid = threadIdx.x;
int bid = blockIdx.x;
int num_blocks = gridDim.x;
if (tid == 0 && bid == 0) {
__threadfence_system();
}
if (bid == 0) {
if (tid < nranks - 1) {
semaphores[tid].signal();
semaphores[tid].wait();
}
}
deviceSyncer.sync(gridDim.x);
// start with a barrier to ensure all devices have written their values
// to their own memory (that is part of the multicast memory)
// before reading them in this kernel
barrier(semaphores, tid, bid, num_blocks, nranks);
int my_st = ((int64_t)nelem * (int64_t)my_rank) / (int64_t)nranks;
int my_en = ((int64_t)nelem * (int64_t)(my_rank + 1)) / (int64_t)nranks;
@@ -815,22 +828,14 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
int my_step = blockDim.x * gridDim.x * 4;
for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
uint4 val;
uint4 val; // fits 8 cutlass::half_t elements; i.e., 4 half2 elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, mc_ptr + idx);
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
}
deviceSyncer.sync(gridDim.x);
if (tid == 0 && bid == 0) {
__threadfence_system();
}
if (bid == 0) {
if (tid < nranks - 1) {
semaphores[tid].signal();
semaphores[tid].wait();
}
}
deviceSyncer.sync(gridDim.x);
// end with a barrier to ensure all devices can now read their values
// from their own memory (that is part of the multicast memory)
// after writing them in this kernel
barrier(semaphores, tid, bid, num_blocks, nranks);
}
#endif