Fix Add in dynamic buffer for fp32/i8 (#2351)

* Fix Add in dynamic buffer for fp32/i8

* fixes

* Fix

[ROCm/composable_kernel commit: cc98a41f46]
This commit is contained in:
Bartłomiej Kocot
2025-06-17 22:25:56 +02:00
committed by GitHub
parent a4517b0a9d
commit f0d44c77d7
2 changed files with 7 additions and 51 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -1841,7 +1841,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CShuffleDataType, // typename DstData,
AccDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
@@ -2591,7 +2591,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CShuffleDataType, // typename DstData,
AccDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,

52
include/ck/utility/dynamic_buffer.hpp Executable file → Normal file
View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -139,8 +139,7 @@ struct DynamicBuffer
template <InMemoryDataOperationEnum Op,
typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value ||
!is_native_type<X>(),
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
__host__ __device__ void Update(IndexType i, bool is_valid_element, const X& x)
{
@@ -160,37 +159,7 @@ struct DynamicBuffer
{
auto tmp = this->template Get<X>(i, is_valid_element);
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
#if defined(__gfx942__) || defined(__gfx950__)
// Properly handle addition for all low-precision types
if constexpr(is_same_v<scalar_t, bhalf_t> || is_same_v<scalar_t, half_t>)
{
if constexpr(is_scalar_type<X>::value)
{
// Scalar type: Convert to float, add, convert back
auto result =
type_convert<X>(type_convert<float>(x) + type_convert<float>(tmp));
this->template Set<X>(i, is_valid_element, result);
}
else
{
// Vector type
constexpr auto vector_size = scalar_type<remove_cvref_t<X>>::vector_size;
const vector_type<scalar_t, vector_size> a_vector{tmp};
const vector_type<scalar_t, vector_size> b_vector{x};
// Process each element of the vector in higher precision
static_for<0, vector_size, 1>{}([&](auto idx) {
auto result = type_convert<scalar_t>(
type_convert<float>(a_vector.template AsType<scalar_t>()[idx]) +
type_convert<float>(b_vector.template AsType<scalar_t>()[idx]));
this->template Set<scalar_t>(i + idx, is_valid_element, result);
});
}
}
#else
// handle bfloat addition
// handle bfloat addition
if constexpr(is_same_v<scalar_t, bhalf_t>)
{
if constexpr(is_scalar_type<X>::value)
@@ -218,8 +187,6 @@ struct DynamicBuffer
{
this->template Set<X>(i, is_valid_element, x + tmp);
}
#endif
}
}
@@ -273,20 +240,9 @@ struct DynamicBuffer
if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
using vector_t = typename vector_type_maker<remove_cvref_t<T>, t_per_x>::type::type;
vector_t tmp;
if constexpr(is_same_v<remove_cvref_t<X>, vector_t>)
{
tmp = x;
}
else
{
__builtin_memcpy(&tmp, &x, sizeof(vector_t));
}
amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
tmp, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&