mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Add support for multicast reduce instruction (#316)
This commit is contained in:
@@ -26,7 +26,7 @@ struct DeviceMulticastPointerDeviceHandle {
|
||||
|
||||
#if defined(MSCCLPP_DEVICE_CUDA)
|
||||
template <int NElemPerThread = 4, typename TValue = float4, typename T = float>
|
||||
MSCCLPP_DEVICE_INLINE static void multimemLoad(TValue& val, T* ptr) {
|
||||
MSCCLPP_DEVICE_INLINE static void multimemLoadReduce(TValue& val, T* ptr) {
|
||||
static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
|
||||
@@ -58,6 +58,22 @@ struct DeviceMulticastPointerDeviceHandle {
|
||||
static_assert(dependentFalse<T>, "Not supported type");
|
||||
}
|
||||
};
|
||||
|
||||
template <int NElemPerThread = 4, typename TValue, typename T>
|
||||
MSCCLPP_DEVICE_INLINE static void multimemStoreReduce(const TValue& val, T* ptr) {
|
||||
static_assert(NElemPerThread == 4, "Only support NElemPerThread == 4");
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
asm volatile("multimem.red.relaxed.sys.global.add.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
|
||||
"r"(val.z), "r"(val.w)
|
||||
: "memory");
|
||||
} else if constexpr (std::is_same<T, half2>::value) {
|
||||
asm volatile("multimem.red.relaxed.sys.global.add.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x),
|
||||
"r"(val.y), "r"(val.z), "r"(val.w)
|
||||
: "memory");
|
||||
} else {
|
||||
static_assert(dependentFalse<T>, "Not supported type");
|
||||
}
|
||||
};
|
||||
#endif // defined(MSCCLPP_DEVICE_CUDA)
|
||||
};
|
||||
|
||||
|
||||
@@ -816,8 +816,8 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
|
||||
for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
|
||||
uint4 val;
|
||||
DeviceMulticastPointerDeviceHandle::multimemLoad(val, mc_ptr + idx);
|
||||
DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
|
||||
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, mc_ptr + idx);
|
||||
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
|
||||
}
|
||||
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
|
||||
@@ -41,8 +41,8 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
|
||||
|
||||
for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
|
||||
uint4 val;
|
||||
DeviceMulticastPointerDeviceHandle::multimemLoad(val, mc_ptr + idx);
|
||||
DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
|
||||
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, mc_ptr + idx);
|
||||
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
|
||||
}
|
||||
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
|
||||
Reference in New Issue
Block a user