[rocm-libraries] ROCm/rocm-libraries#5842 (commit 04c5690)

[CK][CK Tile] Force padding for atomic_add bf16 C tensor
 (#5842)

## Motivation

Force padding for atomic_add bf16 C tensor to avoid memfaults.

## Technical Details

- add global atomic add for bf16 and enable them
- add padding for atomic add bf16 due to the lack of oob
- remove padding for not continous dims in conv for other cases
- minor bwd data conv fixes

## Test Plan

test_grouped_conv_*_tile

## Test Result

pending

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Bartłomiej Kocot
2026-03-31 08:03:41 +00:00
committed by assistant-librarian[bot]
parent 66dc81d530
commit ef4ff4667d
7 changed files with 174 additions and 171 deletions

View File

@@ -18,6 +18,10 @@
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/arch/amd_buffer_coherence.hpp"
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
// This attribute gives a hint to the compiler that a branch is likely to be taken.
// Then, the compiler should remove if possible the associated s_cbranch_execz branch that would
// have been generated.
@@ -2317,6 +2321,34 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
}
}
template <typename T, index_t N>
CK_TILE_DEVICE void
amd_global_atomic_add_impl([[maybe_unused]] const thread_buffer<T, N>& src_thread_data,
[[maybe_unused]] T* addr)
{
static_assert((std::is_same<T, ck_tile::bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, ck_tile::fp16_t>::value && (N == 2 || N == 4 || N == 8)),
"wrong! not implemented");
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
if constexpr(__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16) &&
std::is_same<T, ck_tile::bf16_t>::value)
{
static_for<0, N / 2, 1>{}([&](auto i) {
__builtin_amdgcn_global_atomic_fadd_v2bf16(
bit_cast<ck_tile::bf16x2_t*>(addr) + i,
src_thread_data.template get_as<ck_tile::bf16x2_t>()[i]);
});
}
else
{
static_assert(false, "Not supported!");
}
#else
static_assert(false, "Not supported!");
#endif
}
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_thread_data,
int32x4_t dst_wave_buffer_resource,
@@ -2325,8 +2357,11 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
{
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4))
#if defined(__gfx950__)
|| (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8))
#endif
,
"wrong! not implemented");
if constexpr(std::is_same<T, float>::value)
@@ -2931,16 +2966,27 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
#if defined(__gfx942__)
if constexpr(std::is_same<T, bf16_t>::value)
{
if(dst_thread_element_valid)
{
amd_global_atomic_add_impl<T, N>(src_thread_data,
p_dst_wave + dst_thread_element_offset);
}
}
else
{
#endif
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_add_impl<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
amd_buffer_atomic_add_impl<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_element_valid)
{
@@ -2948,6 +2994,9 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
#if defined(__gfx942__)
}
#endif
}
template <typename T,

View File

@@ -18,6 +18,10 @@
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/arch/amd_buffer_coherence.hpp"
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
namespace ck_tile {
@@ -2143,6 +2147,33 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
}
}
template <typename T, index_t N>
CK_TILE_DEVICE void
amd_global_atomic_add_impl([[maybe_unused]] const thread_buffer<T, N>& src_thread_data,
[[maybe_unused]] T* addr)
{
static_assert((std::is_same<T, ck_tile::bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, ck_tile::fp16_t>::value && (N == 2 || N == 4 || N == 8)),
"wrong! not implemented");
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
if constexpr(std::is_same<T, ck_tile::bf16_t>::value)
{
static_for<0, N / 2, 1>{}([&](auto i) {
__builtin_amdgcn_global_atomic_fadd_v2bf16(
bit_cast<ck_tile::bf16x2_t*>(addr) + i,
src_thread_data.template get_as<ck_tile::bf16x2_t>()[i]);
});
}
else
{
static_assert(false, "Not supported!");
}
#else
static_assert(false, "Not supported!");
#endif
}
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_thread_data,
int32x4_t dst_wave_buffer_resource,
@@ -2151,8 +2182,11 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
{
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4))
#if defined(__gfx950__)
|| (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8))
#endif
,
"wrong! not implemented");
if constexpr(std::is_same<T, float>::value)
@@ -2759,16 +2793,28 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
#if defined(__gfx942__)
if constexpr(std::is_same<T, bf16_t>::value)
{
if(dst_thread_element_valid)
{
amd_global_atomic_add_impl<T, N>(src_thread_data,
p_dst_wave + dst_thread_element_offset);
}
}
else
{
#endif
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_add_impl<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
amd_buffer_atomic_add_impl<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_element_valid)
{
@@ -2776,6 +2822,9 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
#if defined(__gfx942__)
}
#endif
}
template <typename T,

View File

@@ -630,7 +630,7 @@ struct buffer_view<address_space_enum::global,
std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
#if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
#if defined(__gfx942__) || defined(__gfx950__) // only gfx942 and gfx950 support atomic_pk_add_bf16
||
(std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
#endif
@@ -642,7 +642,7 @@ struct buffer_view<address_space_enum::global,
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
#if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
#if defined(__gfx942__) || defined(__gfx950__) // only gfx942 and gfx950 support atomic_pk_add_bf16
||
(std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
#endif