From a783028023138f6dfedf075a59db765fdff0b54e Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Tue, 12 Aug 2025 19:32:46 +0000 Subject: [PATCH] Add atomic add float4 --- .../grouped_conv_bwd_weight_v3_wmma_bf16.cpp | 2 +- .../utility/generic_memory_space_atomic.hpp | 25 ++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_wmma_bf16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_wmma_bf16.cpp index 6e6b877b16..47066889ad 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_wmma_bf16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_wmma_bf16.cpp @@ -65,7 +65,7 @@ using DeviceConvBwdWeightInstance = 1, // CShuffleMRepeatPerShuffle 1, // CShuffleNRepeatPerShuffle S<1, 32, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 2>; // CShuffleBlockTransferScalarPerVector_NPerBlock + 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CShuffleBlockTransferScalarPerVector_NPerBlock template using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight(float2_t* p_dst, const float2_t& x) return vy.template AsType()[I0]; } +template <> +__device__ float4_t atomic_add(float4_t* p_dst, const float4_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicAdd(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicAdd(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + vy.template AsType()(I2) = + atomicAdd(c_style_pointer_cast(p_dst) + 2, vx.template AsType()[I2]); + vy.template AsType()(I3) = + atomicAdd(c_style_pointer_cast(p_dst) + 3, vx.template AsType()[I3]); + + return vy.template AsType()[I0]; +} + template <> __device__ double2_t atomic_add(double2_t* p_dst, const double2_t& x) {