Add templates for fp16 and unsigned short atomic add to fix FBGEMM builds. (#2471)

* add template for fp16 atomic add

* add template for unsigned short atomic add

* use atomicCAS in atomic add for fp16 and unsigned short

[ROCm/composable_kernel commit: 112b47e885]
This commit is contained in:
Illia Silin
2025-07-08 15:09:30 -07:00
committed by GitHub
parent 89f226aace
commit 85af00c08c
2 changed files with 33 additions and 22 deletions

View File

@@ -32,6 +32,33 @@ __device__ float atomic_add<float>(float* p_dst, const float& x)
return atomicAdd(p_dst, x);
}
template <>
__device__ unsigned short atomic_add<unsigned short>(unsigned short* p_dst, const unsigned short& x)
{
unsigned short old_val, new_val;
do
{
old_val = *p_dst;
new_val = old_val + x;
} while(atomicCAS(p_dst, old_val, new_val) != old_val);
return old_val;
}
template <>
__device__ _Float16 atomic_add<_Float16>(_Float16* p_dst, const _Float16& x)
{
_Float16 old_val, new_val;
do
{
old_val = *p_dst;
new_val = old_val + x; // Proper FP16 addition
} while(atomicCAS(reinterpret_cast<unsigned short*>(p_dst),
*reinterpret_cast<unsigned short*>(&old_val),
*reinterpret_cast<unsigned short*>(&new_val)) !=
*reinterpret_cast<unsigned short*>(&old_val));
return old_val;
}
template <>
__device__ double atomic_add<double>(double* p_dst, const double& x)
{

View File

@@ -10,27 +10,11 @@ namespace instance {
#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8))
using GemmF8F8BF16InstanceVector =
std::vector<std::unique_ptr<DeviceGemmV2BPreshuffle<Row,
Col,
Row,
F8,
F8,
BF16,
PassThrough,
PassThrough,
PassThrough>>>&;
using GemmF8F8BF16InstanceVector = std::vector<std::unique_ptr<
DeviceGemmV2BPreshuffle<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&;
using GemmF8F8F16InstanceVector =
std::vector<std::unique_ptr<DeviceGemmV2BPreshuffle<Row,
Col,
Row,
F8,
F8,
F16,
PassThrough,
PassThrough,
PassThrough>>>&;
using GemmF8F8F16InstanceVector = std::vector<std::unique_ptr<
DeviceGemmV2BPreshuffle<Row, Col, Row, F8, F8, F16, PassThrough, PassThrough, PassThrough>>>&;
void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_instances(
GemmF8F8BF16InstanceVector& instances);
@@ -48,7 +32,7 @@ void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p3_instances
GemmF8F8BF16InstanceVector& instances);
void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p4_instances(
GemmF8F8BF16InstanceVector& instances);
GemmF8F8BF16InstanceVector& instances);
void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_instances(
GemmF8F8BF16InstanceVector& instances);
@@ -84,7 +68,7 @@ void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_defau
GemmF8F8F16InstanceVector& instances);
void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances(
GemmF8F8F16InstanceVector& instances);
GemmF8F8F16InstanceVector& instances);
void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances(
GemmF8F8F16InstanceVector& instances);
void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(