From d9316dfbeb3d7b257a7defb2592aee60ae2bfcc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 17 Jun 2025 22:25:56 +0200 Subject: [PATCH] Fix Add in dynamic buffer for fp32/i8 (#2351) * Fix Add in dynamic buffer for fp32/i8 * fixes * Fix [ROCm/composable_kernel commit: cc98a41f465108af2ecf5168c7bd7844a64b6fc5] --- .../gridwise_gemm_xdl_cshuffle_streamk_v3.hpp | 6 +-- include/ck/utility/dynamic_buffer.hpp | 52 ++----------------- 2 files changed, 7 insertions(+), 51 deletions(-) mode change 100755 => 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp mode change 100755 => 100644 include/ck/utility/dynamic_buffer.hpp diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp old mode 100755 new mode 100644 index f1c0ec1c68..d45ed79ae3 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -1841,7 +1841,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, CShuffleDataType, // typename SrcData, - CShuffleDataType, // typename DstData, + AccDataType, // typename DstData, decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), Sequence<0, 1, 2, 3>, // typename DimAccessOrder, @@ -2591,7 +2591,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, CShuffleDataType, // typename SrcData, - CShuffleDataType, // typename DstData, + AccDataType, // typename DstData, decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), Sequence<0, 1, 2, 3>, // typename DimAccessOrder, diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp old mode 100755 new mode 100644 index eb35c34498..2debd09c2d --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -139,8 +139,7 @@ struct DynamicBuffer template >::type, - typename scalar_type>::type>::value || - !is_native_type(), + typename scalar_type>::type>::value, bool>::type = false> __host__ __device__ void Update(IndexType i, bool is_valid_element, const X& x) { @@ -160,37 +159,7 @@ struct DynamicBuffer { auto tmp = this->template Get(i, is_valid_element); using scalar_t = typename scalar_type>::type; - -#if defined(__gfx942__) || defined(__gfx950__) - - // Properly handle addition for all low-precision types - if constexpr(is_same_v || is_same_v) - { - if constexpr(is_scalar_type::value) - { - // Scalar type: Convert to float, add, convert back - auto result = - type_convert(type_convert(x) + type_convert(tmp)); - this->template Set(i, is_valid_element, result); - } - else - { - // Vector type - constexpr auto vector_size = scalar_type>::vector_size; - const vector_type a_vector{tmp}; - const vector_type b_vector{x}; - - // Process each element of the vector in higher precision - static_for<0, vector_size, 1>{}([&](auto idx) { - auto result = type_convert( - type_convert(a_vector.template AsType()[idx]) + - type_convert(b_vector.template AsType()[idx])); - this->template Set(i + idx, is_valid_element, result); - }); - } - } -#else - // handle bfloat addition + // handle bfloat addition if constexpr(is_same_v) { if constexpr(is_scalar_type::value) @@ -218,8 +187,6 @@ struct DynamicBuffer { this->template Set(i, is_valid_element, x + tmp); } - -#endif } } @@ -273,20 +240,9 @@ struct DynamicBuffer if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing) { constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - using vector_t = typename vector_type_maker, t_per_x>::type::type; - vector_t tmp; - - if constexpr(is_same_v, vector_t>) - { - tmp = x; - } - else - { - __builtin_memcpy(&tmp, &x, sizeof(vector_t)); - } amd_buffer_store, t_per_x, coherence>( - tmp, p_data_, i, is_valid_element, element_space_size_ / PackedSize); + x, p_data_, i, is_valid_element, element_space_size_ / PackedSize); } else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && is_same>::type, int8_t>::value &&