mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Add templates for fp16 and unsigned short atomic add to fix FBGEMM builds. (#2471)
* add template for fp16 atomic add * add template for unsigned short atomic add * use atomicCAS in atomic add for fp16 and unsigned short
This commit is contained in:
@@ -32,6 +32,33 @@ __device__ float atomic_add<float>(float* p_dst, const float& x)
|
||||
return atomicAdd(p_dst, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ unsigned short atomic_add<unsigned short>(unsigned short* p_dst, const unsigned short& x)
|
||||
{
|
||||
unsigned short old_val, new_val;
|
||||
do
|
||||
{
|
||||
old_val = *p_dst;
|
||||
new_val = old_val + x;
|
||||
} while(atomicCAS(p_dst, old_val, new_val) != old_val);
|
||||
return old_val;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ _Float16 atomic_add<_Float16>(_Float16* p_dst, const _Float16& x)
|
||||
{
|
||||
_Float16 old_val, new_val;
|
||||
do
|
||||
{
|
||||
old_val = *p_dst;
|
||||
new_val = old_val + x; // Proper FP16 addition
|
||||
} while(atomicCAS(reinterpret_cast<unsigned short*>(p_dst),
|
||||
*reinterpret_cast<unsigned short*>(&old_val),
|
||||
*reinterpret_cast<unsigned short*>(&new_val)) !=
|
||||
*reinterpret_cast<unsigned short*>(&old_val));
|
||||
return old_val;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ double atomic_add<double>(double* p_dst, const double& x)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user