Add support for multicast reduce insruction (#316)

This commit is contained in:
Roshan Dathathri
2024-06-19 13:28:12 -07:00
committed by GitHub
parent 1351f9f1c5
commit 93ed8e1e58
3 changed files with 21 additions and 5 deletions

View File

@@ -26,7 +26,7 @@ struct DeviceMulticastPointerDeviceHandle {
#if defined(MSCCLPP_DEVICE_CUDA)
template <int NElemPerThread = 4, typename TValue = float4, typename T = float>
MSCCLPP_DEVICE_INLINE static void multimemLoad(TValue& val, T* ptr) {
MSCCLPP_DEVICE_INLINE static void multimemLoadReduce(TValue& val, T* ptr) {
static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
if constexpr (std::is_same<T, float>::value) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
@@ -58,6 +58,22 @@ struct DeviceMulticastPointerDeviceHandle {
static_assert(dependentFalse<T>, "Not supported type");
}
};
template <int NElemPerThread = 4, typename TValue, typename T>
MSCCLPP_DEVICE_INLINE static void multimemStoreReduce(const TValue& val, T* ptr) {
static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
if constexpr (std::is_same<T, float>::value) {
asm volatile("multimem.red.relaxed.sys.global.add.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
: "memory");
} else if constexpr (std::is_same<T, half2>::value) {
asm volatile("multimem.red.relaxed.sys.global.add.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x),
"r"(val.y), "r"(val.z), "r"(val.w)
: "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
}
};
#endif // defined(MSCCLPP_DEVICE_CUDA)
};

View File

@@ -816,8 +816,8 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
uint4 val;
DeviceMulticastPointerDeviceHandle::multimemLoad(val, mc_ptr + idx);
DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, mc_ptr + idx);
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
}
deviceSyncer.sync(gridDim.x);

View File

@@ -41,8 +41,8 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
uint4 val;
DeviceMulticastPointerDeviceHandle::multimemLoad(val, mc_ptr + idx);
DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, mc_ptr + idx);
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
}
deviceSyncer.sync(gridDim.x);