mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Similar for half(4|8)_t as well
This commit is contained in:
@@ -146,6 +146,64 @@ __device__ float8_t atomic_add<float8_t>(float8_t* p_dst, const float8_t& x)
|
||||
|
||||
return vy.template AsType<float8_t>()[I0];
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ half4_t atomic_add<half4_t>(half4_t* p_dst, const half4_t& x)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const vector_type<half_t, 4> vx{x};
|
||||
vector_type<half_t, 4> vy{0};
|
||||
|
||||
vy.template AsType<half_t>()(I0) =
|
||||
atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst), vx.template AsType<half_t>()[I0]);
|
||||
vy.template AsType<half_t>()(I1) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 1,
|
||||
vx.template AsType<half_t>()[I1]);
|
||||
vy.template AsType<half_t>()(I2) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 2,
|
||||
vx.template AsType<half_t>()[I2]);
|
||||
vy.template AsType<half_t>()(I3) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 3,
|
||||
vx.template AsType<half_t>()[I3]);
|
||||
|
||||
return vy.template AsType<half4_t>()[I0];
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ half8_t atomic_add<half8_t>(half8_t* p_dst, const half8_t& x)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
const vector_type<half_t, 8> vx{x};
|
||||
vector_type<half_t, 8> vy{0};
|
||||
|
||||
vy.template AsType<half_t>()(I0) =
|
||||
atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst), vx.template AsType<half_t>()[I0]);
|
||||
vy.template AsType<half_t>()(I1) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 1,
|
||||
vx.template AsType<half_t>()[I1]);
|
||||
vy.template AsType<half_t>()(I2) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 2,
|
||||
vx.template AsType<half_t>()[I2]);
|
||||
vy.template AsType<half_t>()(I3) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 3,
|
||||
vx.template AsType<half_t>()[I3]);
|
||||
vy.template AsType<half_t>()(I4) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 4,
|
||||
vx.template AsType<half_t>()[I4]);
|
||||
vy.template AsType<half_t>()(I5) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 5,
|
||||
vx.template AsType<half_t>()[I5]);
|
||||
vy.template AsType<half_t>()(I6) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 6,
|
||||
vx.template AsType<half_t>()[I6]);
|
||||
vy.template AsType<half_t>()(I7) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 7,
|
||||
vx.template AsType<half_t>()[I7]);
|
||||
|
||||
return vy.template AsType<half8_t>()[I0];
|
||||
}
|
||||
#endif // defined(__gfx11__)
|
||||
|
||||
// Caution: DO NOT REMOVE
|
||||
|
||||
Reference in New Issue
Block a user