mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Simplify/improve barrier in AllReduce6 (#317)
Drop superfluous __threadfence_system()
This commit is contained in:
@@ -788,6 +788,24 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
|
||||
|
||||
// Barrier among all devices followed by a memory fence
|
||||
// Should be called by all threads on all devices
|
||||
// Assumes \p num_threads_per_block >= \p num_ranks
|
||||
__forceinline__ __device__ void barrier(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores, int thread_id,
|
||||
int block_id, int num_blocks, int num_ranks) {
|
||||
// wait for every device
|
||||
if (block_id == 0) {
|
||||
// 1 less than the num_ranks because there is no semaphore for self
|
||||
if (thread_id < num_ranks - 1) {
|
||||
semaphores[thread_id].signal();
|
||||
semaphores[thread_id].wait();
|
||||
}
|
||||
}
|
||||
|
||||
// wait for every thread in every block on this device
|
||||
deviceSyncer.sync(num_blocks);
|
||||
}
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
allreduce6(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
|
||||
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, TYPE* buff, int my_rank, int nranks,
|
||||
@@ -796,17 +814,12 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
float* mc_ptr = (float*)nvlsPtrs.mcPtr;
|
||||
int tid = threadIdx.x;
|
||||
int bid = blockIdx.x;
|
||||
int num_blocks = gridDim.x;
|
||||
|
||||
if (tid == 0 && bid == 0) {
|
||||
__threadfence_system();
|
||||
}
|
||||
if (bid == 0) {
|
||||
if (tid < nranks - 1) {
|
||||
semaphores[tid].signal();
|
||||
semaphores[tid].wait();
|
||||
}
|
||||
}
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
// start with a barrier to ensure all devices have written their values
|
||||
// to their own memory (that is part of the multicast memory)
|
||||
// before reading them in this kernel
|
||||
barrier(semaphores, tid, bid, num_blocks, nranks);
|
||||
|
||||
int my_st = ((int64_t)nelem * (int64_t)my_rank) / (int64_t)nranks;
|
||||
int my_en = ((int64_t)nelem * (int64_t)(my_rank + 1)) / (int64_t)nranks;
|
||||
@@ -815,22 +828,14 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
int my_step = blockDim.x * gridDim.x * 4;
|
||||
|
||||
for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
|
||||
uint4 val;
|
||||
uint4 val; // fits 8 cutlass::half_t elements; i.e., 4 half2 elements
|
||||
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, mc_ptr + idx);
|
||||
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
|
||||
}
|
||||
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
if (tid == 0 && bid == 0) {
|
||||
__threadfence_system();
|
||||
}
|
||||
|
||||
if (bid == 0) {
|
||||
if (tid < nranks - 1) {
|
||||
semaphores[tid].signal();
|
||||
semaphores[tid].wait();
|
||||
}
|
||||
}
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
// end with a barrier to ensure all devices can now read their values
|
||||
// from their own memory (that is part of the multicast memory)
|
||||
// after writing them in this kernel
|
||||
barrier(semaphores, tid, bid, num_blocks, nranks);
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user