Similar for half(4|8)_t as well

This commit is contained in:
Graner, Johannes
2025-12-09 07:10:58 +00:00
parent 3883fe6935
commit 5311ec608d

View File

@@ -146,6 +146,64 @@ __device__ float8_t atomic_add<float8_t>(float8_t* p_dst, const float8_t& x)
return vy.template AsType<float8_t>()[I0];
}
template <>
__device__ half4_t atomic_add<half4_t>(half4_t* p_dst, const half4_t& x)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const vector_type<half_t, 4> vx{x};
vector_type<half_t, 4> vy{0};
vy.template AsType<half_t>()(I0) =
atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst), vx.template AsType<half_t>()[I0]);
vy.template AsType<half_t>()(I1) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 1,
vx.template AsType<half_t>()[I1]);
vy.template AsType<half_t>()(I2) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 2,
vx.template AsType<half_t>()[I2]);
vy.template AsType<half_t>()(I3) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 3,
vx.template AsType<half_t>()[I3]);
return vy.template AsType<half4_t>()[I0];
}
template <>
__device__ half8_t atomic_add<half8_t>(half8_t* p_dst, const half8_t& x)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
const vector_type<half_t, 8> vx{x};
vector_type<half_t, 8> vy{0};
vy.template AsType<half_t>()(I0) =
atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst), vx.template AsType<half_t>()[I0]);
vy.template AsType<half_t>()(I1) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 1,
vx.template AsType<half_t>()[I1]);
vy.template AsType<half_t>()(I2) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 2,
vx.template AsType<half_t>()[I2]);
vy.template AsType<half_t>()(I3) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 3,
vx.template AsType<half_t>()[I3]);
vy.template AsType<half_t>()(I4) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 4,
vx.template AsType<half_t>()[I4]);
vy.template AsType<half_t>()(I5) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 5,
vx.template AsType<half_t>()[I5]);
vy.template AsType<half_t>()(I6) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 6,
vx.template AsType<half_t>()[I6]);
vy.template AsType<half_t>()(I7) = atomic_add<half_t>(c_style_pointer_cast<half_t*>(p_dst) + 7,
vx.template AsType<half_t>()[I7]);
return vy.template AsType<half8_t>()[I0];
}
#endif // defined(__gfx11__)
// Caution: DO NOT REMOVE