From 0db55aad9476db60c961da12ae882bdefaa2202c Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 8 Jul 2025 19:01:26 -0700 Subject: [PATCH] =?UTF-8?q?Revert=20"Add=20templates=20for=20fp16=20and=20?= =?UTF-8?q?unsigned=20short=20atomic=20add=20to=20fix=20FBGEMM=20bu?= =?UTF-8?q?=E2=80=A6"=20(#2474)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit b00bbe4b355adf7b2778646d7c9c7633307836be. [ROCm/composable_kernel commit: 93420ecf89d0747c35b096aa95453eaaceb0aea3] --- .../utility/generic_memory_space_atomic.hpp | 27 ------------------ .../gpu/gemm_universal_preshuffle.inc | 28 +++++++++++++++---- 2 files changed, 22 insertions(+), 33 deletions(-) diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp index 3dda8af8e2..ab9cc4199c 100644 --- a/include/ck/utility/generic_memory_space_atomic.hpp +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -32,33 +32,6 @@ __device__ float atomic_add(float* p_dst, const float& x) return atomicAdd(p_dst, x); } -template <> -__device__ unsigned short atomic_add(unsigned short* p_dst, const unsigned short& x) -{ - unsigned short old_val, new_val; - do - { - old_val = *p_dst; - new_val = old_val + x; - } while(atomicCAS(p_dst, old_val, new_val) != old_val); - return old_val; -} - -template <> -__device__ _Float16 atomic_add<_Float16>(_Float16* p_dst, const _Float16& x) -{ - _Float16 old_val, new_val; - do - { - old_val = *p_dst; - new_val = old_val + x; // Proper FP16 addition - } while(atomicCAS(reinterpret_cast(p_dst), - *reinterpret_cast(&old_val), - *reinterpret_cast(&new_val)) != - *reinterpret_cast(&old_val)); - return old_val; -} - template <> __device__ double atomic_add(double* p_dst, const double& x) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc index b987519082..b44d60deaf 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc @@ -10,11 +10,27 @@ namespace instance { #if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) -using GemmF8F8BF16InstanceVector = std::vector>>&; +using GemmF8F8BF16InstanceVector = + std::vector>>&; -using GemmF8F8F16InstanceVector = std::vector>>&; +using GemmF8F8F16InstanceVector = + std::vector>>&; void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_instances( GemmF8F8BF16InstanceVector& instances); @@ -32,7 +48,7 @@ void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p3_instances GemmF8F8BF16InstanceVector& instances); void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p4_instances( - GemmF8F8BF16InstanceVector& instances); + GemmF8F8BF16InstanceVector& instances); void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_instances( GemmF8F8BF16InstanceVector& instances); @@ -68,7 +84,7 @@ void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_defau GemmF8F8F16InstanceVector& instances); void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances( - GemmF8F8F16InstanceVector& instances); + GemmF8F8F16InstanceVector& instances); void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances( GemmF8F8F16InstanceVector& instances); void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(