mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Update flatmm related kernels (#3022)
--------- Co-authored-by: Ding, Yi <yi.ding@amd.com> Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
@@ -1303,6 +1303,15 @@ CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
|
||||
|
||||
// buffer atomic-add bf16
|
||||
// TODO: Replace with bf16x2_t, but llvm builins only accept cktile_bf16x2_t now.
|
||||
CK_TILE_DEVICE_EXTERN bf16x2_t llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
|
||||
bf16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2bf16");
|
||||
|
||||
// buffer atomic-add i32
|
||||
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
|
||||
int32_t vdata,
|
||||
@@ -1537,8 +1546,11 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, e8m0_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, pk_int4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
|
||||
(std::is_same<T, pk_fp4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
|
||||
"wrong! not implemented");
|
||||
|
||||
using rtn_type = thread_buffer<T, N>;
|
||||
@@ -2262,6 +2274,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
|
||||
{
|
||||
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
|
||||
"wrong! not implemented");
|
||||
|
||||
@@ -2355,6 +2368,39 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
|
||||
});
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf16_t>::value)
|
||||
{
|
||||
if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(bit_cast<bf16x2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
|
||||
src_thread_data.template get_as<bf16x2_t>()[i],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + i * sizeof(bf16x2_t),
|
||||
0);
|
||||
});
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
|
||||
src_thread_data.template get_as<bf16x2_t>()[i],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + i * sizeof(bf16x2_t),
|
||||
0);
|
||||
});
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
|
||||
@@ -1171,6 +1171,15 @@ CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
|
||||
|
||||
// buffer atomic-add bf16
|
||||
// TODO: Replace with bf16x2_t, but llvm builins only accept cktile_bf16x2_t now.
|
||||
CK_TILE_DEVICE_EXTERN bf16x2_t llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
|
||||
bf16x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2bf16");
|
||||
|
||||
// buffer atomic-add i32
|
||||
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
|
||||
int32_t vdata,
|
||||
@@ -1405,10 +1414,14 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, e8m0_bexp_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, pk_fp4_raw_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, pk_int4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
|
||||
(std::is_same<T, pk_fp4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
|
||||
"wrong! not implemented");
|
||||
|
||||
using rtn_type = thread_buffer<T, N>;
|
||||
@@ -2047,6 +2060,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
|
||||
{
|
||||
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
|
||||
"wrong! not implemented");
|
||||
|
||||
@@ -2140,6 +2154,39 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
|
||||
});
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf16_t>::value)
|
||||
{
|
||||
if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(bit_cast<bf16x2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
|
||||
src_thread_data.template get_as<bf16x2_t>()[i],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + i * sizeof(bf16x2_t),
|
||||
0);
|
||||
});
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
|
||||
src_thread_data.template get_as<bf16x2_t>()[i],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + i * sizeof(bf16x2_t),
|
||||
0);
|
||||
});
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
|
||||
Reference in New Issue
Block a user