From e61ceee502e4c3bfbd4a1d90a8ec7db83961f3ca Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 10 Jul 2025 07:18:56 -0700 Subject: [PATCH] Add declarations for atomic add for fp16 and unsigned short. (#2483) * add template for fp16 atomic add * add template for unsigned short atomic add * use atomicCAS in atomic add for fp16 and unsigned short * revert back to atomic add using casting [ROCm/composable_kernel commit: 1b66f3f4a32f1e755e8ac70a16e879f4f6523870] --- .../utility/generic_memory_space_atomic.hpp | 16 +++++++++++ .../gpu/gemm_universal_preshuffle.inc | 28 ++++--------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp index ab9cc4199c..011491ffc6 100644 --- a/include/ck/utility/generic_memory_space_atomic.hpp +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -32,6 +32,22 @@ __device__ float atomic_add<float>(float* p_dst, const float& x) { return atomicAdd(p_dst, x); } +template <> +__device__ unsigned short atomic_add<unsigned short>(unsigned short* p_dst, const unsigned short& x) +{ + // Use atomicAdd with unsigned int + return static_cast<unsigned short>( + atomicAdd(reinterpret_cast<unsigned int*>(p_dst), static_cast<unsigned int>(x))); +} + +template <> +__device__ _Float16 atomic_add<_Float16>(_Float16* p_dst, const _Float16& x) +{ + // Use atomicAdd with unsigned int + return static_cast<_Float16>( + atomicAdd(reinterpret_cast<unsigned int*>(p_dst), static_cast<unsigned int>(x))); +} + template <> __device__ double atomic_add<double>(double* p_dst, const double& x) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc index b44d60deaf..b987519082 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc @@ 
-10,27 +10,11 @@ namespace instance { #if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) -using GemmF8F8BF16InstanceVector = - std::vector>>&; +using GemmF8F8BF16InstanceVector = std::vector>>&; -using GemmF8F8F16InstanceVector = - std::vector>>&; +using GemmF8F8F16InstanceVector = std::vector>>&; void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_instances( GemmF8F8BF16InstanceVector& instances); @@ -48,7 +32,7 @@ void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p3_instances GemmF8F8BF16InstanceVector& instances); void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p4_instances( - GemmF8F8BF16InstanceVector& instances); + GemmF8F8BF16InstanceVector& instances); void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_instances( GemmF8F8BF16InstanceVector& instances); @@ -84,7 +68,7 @@ void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_defau GemmF8F8F16InstanceVector& instances); void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances( - GemmF8F8F16InstanceVector& instances); + GemmF8F8F16InstanceVector& instances); void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances( GemmF8F8F16InstanceVector& instances); void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(