[CK_TILE] Update flatmm related kernels (#3022)

---------

Co-authored-by: Ding, Yi <yi.ding@amd.com>
Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
lalala-sh
2025-10-22 22:36:11 +08:00
committed by GitHub
parent cbd1279ae6
commit 211d64e18a
39 changed files with 11183 additions and 739 deletions

View File

@@ -1303,6 +1303,15 @@ CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
// buffer atomic-add bf16
// TODO: Replace with bf16x2_t, but llvm builins only accept cktile_bf16x2_t now.
CK_TILE_DEVICE_EXTERN bf16x2_t llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
bf16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2bf16");
// buffer atomic-add i32
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
int32_t vdata,
@@ -1537,8 +1546,11 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, e8m0_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_int4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
(std::is_same<T, pk_fp4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
"wrong! not implemented");
using rtn_type = thread_buffer<T, N>;
@@ -2262,6 +2274,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
{
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
"wrong! not implemented");
@@ -2355,6 +2368,39 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
});
}
}
else if constexpr(std::is_same<T, bf16_t>::value)
{
if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(bit_cast<bf16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
static_for<0, 2, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
src_thread_data.template get_as<bf16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(bf16x2_t),
0);
});
}
else if constexpr(N == 8)
{
static_for<0, 4, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
src_thread_data.template get_as<bf16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(bf16x2_t),
0);
});
}
}
else if constexpr(std::is_same<T, int32_t>::value)
{
if constexpr(N == 1)

View File

@@ -1171,6 +1171,15 @@ CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
// buffer atomic-add bf16
// TODO: Replace with bf16x2_t, but llvm builins only accept cktile_bf16x2_t now.
CK_TILE_DEVICE_EXTERN bf16x2_t llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
bf16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2bf16");
// buffer atomic-add i32
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
int32_t vdata,
@@ -1405,10 +1414,14 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, e8m0_bexp_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_fp4_raw_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_int4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
(std::is_same<T, pk_fp4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
"wrong! not implemented");
using rtn_type = thread_buffer<T, N>;
@@ -2047,6 +2060,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
{
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
"wrong! not implemented");
@@ -2140,6 +2154,39 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
});
}
}
else if constexpr(std::is_same<T, bf16_t>::value)
{
if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(bit_cast<bf16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
static_for<0, 2, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
src_thread_data.template get_as<bf16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(bf16x2_t),
0);
});
}
else if constexpr(N == 8)
{
static_for<0, 4, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
src_thread_data.template get_as<bf16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(bf16x2_t),
0);
});
}
}
else if constexpr(std::is_same<T, int32_t>::value)
{
if constexpr(N == 1)