From 5311ec608d2fd5fa7835b213f8dc714ca68ce8d2 Mon Sep 17 00:00:00 2001 From: "Graner, Johannes" Date: Tue, 9 Dec 2025 07:10:58 +0000 Subject: [PATCH] Similar for half(4|8)_t as well --- .../utility/generic_memory_space_atomic.hpp | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp index 3deedaa403..961a0e1d77 100644 --- a/include/ck/utility/generic_memory_space_atomic.hpp +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -146,6 +146,64 @@ __device__ float8_t atomic_add(float8_t* p_dst, const float8_t& x) return vy.template AsType()[I0]; } + +template <> +__device__ half4_t atomic_add(half4_t* p_dst, const half4_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomic_add(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = atomic_add(c_style_pointer_cast(p_dst) + 1, + vx.template AsType()[I1]); + vy.template AsType()(I2) = atomic_add(c_style_pointer_cast(p_dst) + 2, + vx.template AsType()[I2]); + vy.template AsType()(I3) = atomic_add(c_style_pointer_cast(p_dst) + 3, + vx.template AsType()[I3]); + + return vy.template AsType()[I0]; +} + +template <> +__device__ half8_t atomic_add(half8_t* p_dst, const half8_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomic_add(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = atomic_add(c_style_pointer_cast(p_dst) + 1, + vx.template AsType()[I1]); + vy.template AsType()(I2) = atomic_add(c_style_pointer_cast(p_dst) + 2, + vx.template AsType()[I2]); + vy.template AsType()(I3) = atomic_add(c_style_pointer_cast(p_dst) + 3, + vx.template AsType()[I3]); + vy.template AsType()(I4) = atomic_add(c_style_pointer_cast(p_dst) + 4, + vx.template AsType()[I4]); + vy.template AsType()(I5) = atomic_add(c_style_pointer_cast(p_dst) + 5, + vx.template AsType()[I5]); + vy.template AsType()(I6) = atomic_add(c_style_pointer_cast(p_dst) + 6, + vx.template AsType()[I6]); + vy.template AsType()(I7) = atomic_add(c_style_pointer_cast(p_dst) + 7, + vx.template AsType()[I7]); + + return vy.template AsType()[I0]; +} #endif // defined(__gfx11__) // Caution: DO NOT REMOVE