Fix mscclpp_benchmark allreduce test
This commit is contained in:
Binyang Li
2025-07-13 18:02:56 -07:00
committed by GitHub
parent 5e991cf5c8
commit 604c345921

View File

@@ -798,13 +798,13 @@ __forceinline__ __device__ void barrier(mscclpp::MemoryDevice2DeviceSemaphoreDev
// Assumes kVecSize is 1, 2, 4, or 8
template <typename DataType, int kVecSize>
MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores,
mscclpp::SwitchChannelDeviceHandle switchChan, size_t my_rank,
size_t num_ranks, size_t num_elements) {
mscclpp::SwitchChannelDeviceHandle switchChan, int my_rank, int num_ranks,
size_t num_elements) {
using VectorType = mscclpp::VectorType<DataType, kVecSize>;
size_t tid = threadIdx.x;
size_t bid = blockIdx.x;
size_t num_threads_per_block = blockDim.x;
size_t num_blocks = gridDim.x;
int tid = threadIdx.x;
int bid = blockIdx.x;
int num_threads_per_block = blockDim.x;
int num_blocks = gridDim.x;
// start with a barrier to ensure all devices have written their values
// to their own memory (that is part of the multicast memory)
@@ -821,7 +821,7 @@ MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::MemoryDevice2DeviceSemapho
for (size_t idx = rank_start + thread_offset; idx < rank_end; idx += thread_step) {
auto val = switchChan.reduce<VectorType>(idx);
switchChan.store(idx, val);
switchChan.broadcast(idx, val);
}
// end with a barrier to ensure all devices can now read their values
@@ -832,16 +832,31 @@ MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::MemoryDevice2DeviceSemapho
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce6(mscclpp::MemoryDevice2DeviceSemaphoreDeviceHandle* semaphores,
mscclpp::SwitchChannelDeviceHandle switchChan, size_t my_rank, size_t num_ranks, size_t num_elements,
mscclpp::SwitchChannelDeviceHandle switchChan, int my_rank, int num_ranks, size_t num_elements,
size_t vector_size) {
if (vector_size == 8) {
allreduce6_helper<TYPE, 8>(semaphores, switchChan, my_rank, num_ranks, num_elements);
} else if (vector_size == 4) {
allreduce6_helper<TYPE, 4>(semaphores, switchChan, my_rank, num_ranks, num_elements);
} else if (vector_size == 2) {
allreduce6_helper<TYPE, 2>(semaphores, switchChan, my_rank, num_ranks, num_elements);
if constexpr (sizeof(TYPE) == 4) {
if (vector_size == 4) {
allreduce6_helper<TYPE, 4>(semaphores, switchChan, my_rank, num_ranks, num_elements);
} else if (vector_size == 2) {
allreduce6_helper<TYPE, 2>(semaphores, switchChan, my_rank, num_ranks, num_elements);
} else if (vector_size == 1) {
allreduce6_helper<TYPE, 1>(semaphores, switchChan, my_rank, num_ranks, num_elements);
} else {
assert(false && "Unsupported vector size for allreduce6.");
}
} else if constexpr (sizeof(TYPE) == 2) {
if (vector_size == 8) {
allreduce6_helper<TYPE, 8>(semaphores, switchChan, my_rank, num_ranks, num_elements);
} else if (vector_size == 4) {
allreduce6_helper<TYPE, 4>(semaphores, switchChan, my_rank, num_ranks, num_elements);
} else if (vector_size == 2) {
allreduce6_helper<TYPE, 2>(semaphores, switchChan, my_rank, num_ranks, num_elements);
} else {
assert(false && "Unsupported vector size for allreduce6.");
}
} else {
allreduce6_helper<TYPE, 1>(semaphores, switchChan, my_rank, num_ranks, num_elements);
// unsupported vector size
static_assert(sizeof(TYPE) == 4 || sizeof(TYPE) == 2, "Unsupported TYPE size for allreduce6.");
}
}
#endif