mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
@@ -798,13 +798,13 @@ __forceinline__ __device__ void barrier(mscclpp::MemoryDevice2DeviceSemaphoreDev
|
||||
// Assumes kVecSize is 1, 2, 4, or 8
|
||||
template <typename DataType, int kVecSize>
|
||||
MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores,
|
||||
mscclpp::SwitchChannelDeviceHandle switchChan, size_t my_rank,
|
||||
size_t num_ranks, size_t num_elements) {
|
||||
mscclpp::SwitchChannelDeviceHandle switchChan, int my_rank, int num_ranks,
|
||||
size_t num_elements) {
|
||||
using VectorType = mscclpp::VectorType<DataType, kVecSize>;
|
||||
size_t tid = threadIdx.x;
|
||||
size_t bid = blockIdx.x;
|
||||
size_t num_threads_per_block = blockDim.x;
|
||||
size_t num_blocks = gridDim.x;
|
||||
int tid = threadIdx.x;
|
||||
int bid = blockIdx.x;
|
||||
int num_threads_per_block = blockDim.x;
|
||||
int num_blocks = gridDim.x;
|
||||
|
||||
// start with a barrier to ensure all devices have written their values
|
||||
// to their own memory (that is part of the multicast memory)
|
||||
@@ -821,7 +821,7 @@ MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::MemoryDevice2DeviceSemapho
|
||||
|
||||
for (size_t idx = rank_start + thread_offset; idx < rank_end; idx += thread_step) {
|
||||
auto val = switchChan.reduce<VectorType>(idx);
|
||||
switchChan.store(idx, val);
|
||||
switchChan.broadcast(idx, val);
|
||||
}
|
||||
|
||||
// end with a barrier to ensure all devices can now read their values
|
||||
@@ -832,16 +832,31 @@ MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::MemoryDevice2DeviceSemapho
|
||||
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
allreduce6(mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores,
|
||||
mscclpp::SwitchChannelDeviceHandle switchChan, size_t my_rank, size_t num_ranks, size_t num_elements,
|
||||
mscclpp::SwitchChannelDeviceHandle switchChan, int my_rank, int num_ranks, size_t num_elements,
|
||||
size_t vector_size) {
|
||||
if (vector_size == 8) {
|
||||
allreduce6_helper<TYPE, 8>(semaphores, switchChan, my_rank, num_ranks, num_elements);
|
||||
} else if (vector_size == 4) {
|
||||
allreduce6_helper<TYPE, 4>(semaphores, switchChan, my_rank, num_ranks, num_elements);
|
||||
} else if (vector_size == 2) {
|
||||
allreduce6_helper<TYPE, 2>(semaphores, switchChan, my_rank, num_ranks, num_elements);
|
||||
if constexpr (sizeof(TYPE) == 4) {
|
||||
if (vector_size == 4) {
|
||||
allreduce6_helper<TYPE, 4>(semaphores, switchChan, my_rank, num_ranks, num_elements);
|
||||
} else if (vector_size == 2) {
|
||||
allreduce6_helper<TYPE, 2>(semaphores, switchChan, my_rank, num_ranks, num_elements);
|
||||
} else if (vector_size == 1) {
|
||||
allreduce6_helper<TYPE, 1>(semaphores, switchChan, my_rank, num_ranks, num_elements);
|
||||
} else {
|
||||
assert(false && "Unsupported vector size for allreduce6.");
|
||||
}
|
||||
} else if constexpr (sizeof(TYPE) == 2) {
|
||||
if (vector_size == 8) {
|
||||
allreduce6_helper<TYPE, 8>(semaphores, switchChan, my_rank, num_ranks, num_elements);
|
||||
} else if (vector_size == 4) {
|
||||
allreduce6_helper<TYPE, 4>(semaphores, switchChan, my_rank, num_ranks, num_elements);
|
||||
} else if (vector_size == 2) {
|
||||
allreduce6_helper<TYPE, 2>(semaphores, switchChan, my_rank, num_ranks, num_elements);
|
||||
} else {
|
||||
assert(false && "Unsupported vector size for allreduce6.");
|
||||
}
|
||||
} else {
|
||||
allreduce6_helper<TYPE, 1>(semaphores, switchChan, my_rank, num_ranks, num_elements);
|
||||
// unsupported vector size
|
||||
static_assert(sizeof(TYPE) == 4 || sizeof(TYPE) == 2, "Unsupported TYPE size for allreduce6.");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user