[CK_TILE] Update flatmm related kernels (#3022)

---------

Co-authored-by: Ding, Yi <yi.ding@amd.com>
Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
lalala-sh
2025-10-22 22:36:11 +08:00
committed by GitHub
parent cbd1279ae6
commit 211d64e18a
39 changed files with 11183 additions and 739 deletions

View File

@@ -1303,6 +1303,15 @@ CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
// buffer atomic-add bf16
// Extern declaration bound (via __asm) to the LLVM raw-buffer atomic fadd
// builtin for a packed pair of bf16 values. Arguments follow the raw-buffer
// intrinsic convention: vdata = packed bf16 pair to add, rsrc = 128-bit
// buffer resource descriptor, voffset = per-thread byte offset,
// soffset = wave-uniform byte offset, glc_slc = cache-policy bits.
// TODO: Replace with bf16x2_t, but llvm builtins only accept cktile_bf16x2_t now.
CK_TILE_DEVICE_EXTERN bf16x2_t llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
bf16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2bf16");
// buffer atomic-add i32
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
int32_t vdata,
@@ -1537,8 +1546,11 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, e8m0_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_int4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
(std::is_same<T, pk_fp4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
"wrong! not implemented");
using rtn_type = thread_buffer<T, N>;
@@ -2262,6 +2274,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
{
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
"wrong! not implemented");
@@ -2355,6 +2368,39 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
});
}
}
else if constexpr(std::is_same<T, bf16_t>::value)
{
if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(bit_cast<bf16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
static_for<0, 2, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
src_thread_data.template get_as<bf16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(bf16x2_t),
0);
});
}
else if constexpr(N == 8)
{
static_for<0, 4, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
src_thread_data.template get_as<bf16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(bf16x2_t),
0);
});
}
}
else if constexpr(std::is_same<T, int32_t>::value)
{
if constexpr(N == 1)

View File

@@ -1171,6 +1171,15 @@ CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
// buffer atomic-add bf16
// Extern declaration bound (via __asm) to the LLVM raw-buffer atomic fadd
// builtin for a packed pair of bf16 values. Arguments follow the raw-buffer
// intrinsic convention: vdata = packed bf16 pair to add, rsrc = 128-bit
// buffer resource descriptor, voffset = per-thread byte offset,
// soffset = wave-uniform byte offset, glc_slc = cache-policy bits.
// TODO: Replace with bf16x2_t, but llvm builtins only accept cktile_bf16x2_t now.
CK_TILE_DEVICE_EXTERN bf16x2_t llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
bf16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2bf16");
// buffer atomic-add i32
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
int32_t vdata,
@@ -1405,10 +1414,14 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, e8m0_bexp_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_fp4_raw_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_int4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
(std::is_same<T, pk_fp4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
"wrong! not implemented");
using rtn_type = thread_buffer<T, N>;
@@ -2047,6 +2060,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
{
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
"wrong! not implemented");
@@ -2140,6 +2154,39 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
});
}
}
else if constexpr(std::is_same<T, bf16_t>::value)
{
if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(bit_cast<bf16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
static_for<0, 2, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
src_thread_data.template get_as<bf16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(bf16x2_t),
0);
});
}
else if constexpr(N == 8)
{
static_for<0, 4, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
src_thread_data.template get_as<bf16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(bf16x2_t),
0);
});
}
}
else if constexpr(std::is_same<T, int32_t>::value)
{
if constexpr(N == 1)

View File

@@ -11,6 +11,7 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/e8m0.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
@@ -88,7 +89,12 @@ template <typename T, typename = void>
struct vector_traits
{
using scalar_type =
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>, int8_t, remove_cvref_t<T>>;
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>,
int8_t,
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_fp4_t> ||
std::is_same_v<remove_cvref_t<T>, e8m0_t>,
uint8_t,
remove_cvref_t<T>>>;
static constexpr index_t vector_size = 1;
};
@@ -96,7 +102,12 @@ struct vector_traits
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N))), void>
{
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
using scalar_type = std::conditional_t<
std::is_same_v<T, pk_int4_t>,
int8_t,
std::conditional_t<std::is_same_v<T, pk_fp4_t> || std::is_same_v<remove_cvref_t<T>, e8m0_t>,
uint8_t,
T>>;
static constexpr index_t vector_size = N;
};
@@ -237,4 +248,10 @@ using pk_int4x4_t = int8_t __attribute__((ext_vector_type(4)));
using pk_int4x8_t = int8_t __attribute__((ext_vector_type(8)));
using pk_int4x16_t = int8_t __attribute__((ext_vector_type(16)));
using pk_int4x32_t = int8_t __attribute__((ext_vector_type(32)));
using pk_fp4x2_t = uint8_t __attribute((ext_vector_type(2)));
using pk_fp4x4_t = uint8_t __attribute((ext_vector_type(4)));
using pk_fp4x8_t = uint8_t __attribute((ext_vector_type(8)));
using pk_fp4x16_t = uint8_t __attribute((ext_vector_type(16)));
using pk_fp4x32_t = uint8_t __attribute((ext_vector_type(32)));
} // namespace ck_tile

View File

@@ -247,7 +247,7 @@ struct buffer_view<address_space_enum::global,
: p_data_{p_data},
buffer_size_{buffer_size / PackedSize},
cached_buf_res_{0},
invalid_element_value_{0}
invalid_element_value_{}
{
}
@@ -631,14 +631,24 @@ struct buffer_view<address_space_enum::global,
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
#if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
||
(std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
#endif
;
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
#if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
||
(std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
#endif
;
#else
bool constexpr use_amd_buffer_addressing = false;
#endif

View File

@@ -1,3 +1,4 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
@@ -404,6 +405,100 @@ struct tile_scatter_gather
});
}
// Asynchronously loads one tile from the bottom (global) tensor into LDS.
// For each access, the index along the Y-space gather dimension
// (YsGatherDim) is resolved through page_idx_ and folded into the
// bottom-tensor coordinate, so rows are gathered from possibly
// non-contiguous pages; when a ValidArray is supplied, the row's validity
// flag is forwarded so invalid rows can be masked by the underlying load.
//
// lds_tile              : destination LDS tile window (provides the SMEM
//                         descriptor and base pointer).
// i_access_unsupport_   : kept to mirror the other load entry points.
// oob_conditional_check : enables out-of-bounds checking in the underlying
//                         async gather.
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
// NOTE(review): tile_dstr is not referenced anywhere below in this overload.
constexpr auto tile_dstr = TileDstr{};
// Precompute invariant values outside loops
const auto window_origin = lds_tile.get_window_origin();
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
// NOTE(review): the LDS-side coordinates are seeded from the same
// pre_computed_coords_ as the global-side ones — confirm the LDS window
// shares this window's distribution.
auto lds_window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto lds_bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// Use precomputed window origin
auto lds_bottom_tensor_thread_idx =
window_origin + lds_window_adaptor_thread_coord.get_bottom_index();
// Use precomputed tensor descriptor
const auto lds_coord =
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
// Calculate SMEM address using base pointer
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// Resolve this access's page through the gather dimension's start index.
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
const auto page_offset = page_idx_[idx_gather];
// merge page_offset into bottom_coord
auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;
// read from bottom tensor
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
mixed_bottom_thread_coord,
number<0>{},
bool_constant<oob_conditional_check>{});
else
// With a validity mask, forward the row's valid flag as well.
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
mixed_bottom_thread_coord,
number<0>{},
valids_[idx_gather],
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
// Zero the step along the gather dimension on the global side: that
// dimension advances via page_idx_, not via the linear forward step.
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
// lds_diff doesn't need to mask the difference of the gather-dim.
constexpr auto lds_idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
lds_window_adaptor_thread_coord,
lds_bottom_tensor_thread_coord,
lds_idx_diff_ps_ys);
}
});
});
}
// TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
@@ -508,6 +603,88 @@ struct tile_scatter_gather
});
}
// Applies a distributed tile to the bottom tensor through a read-modify-write
// ("update") access path: each vector is assembled from the thread-local
// distributed tensor, then handed to update_vectorized_elements together with
// the row's page offset (page_idx_) and, when a ValidArray is present, the
// row's validity flag.
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// NOTE(review): the gather index is taken from Y-dimension 0 here, while
// async_load uses number<YsGatherDim>{} — confirm this is intentional for
// windows where YsGatherDim != 0.
constexpr auto idx_gather = idx_ys_start[number<0>{}];
const auto page_offset = page_idx_[idx_gather];
// read from distributed tensor
// Assemble one vector: walk ScalarPerVector elements along VectorDimY,
// stepping by PackedSize for packed element types.
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
// Flatten the Y-index into the distributed tensor's linear offset.
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
vec_value,
bool_constant<oob_conditional_check>{});
}
else
{
// With a validity mask, forward the row's valid flag as well.
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
valids_[idx_gather],
vec_value,
bool_constant<oob_conditional_check>{});
}
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
// Zero the step along the gather dimension: it advances via page_idx_.
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
@@ -855,4 +1032,29 @@ CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
number<HsGatherDim>{});
}
// Builds a new tile_scatter_gather that reuses all the window state of an
// existing window (lengths, origin, distribution, page indices, validity
// mask) but sits on top of a different bottom tensor view.
//
// NOTE(review): HsGatherDim and NumCoord are deduced from the source window
// but not forwarded to make_tile_scatter_gather, so the rebuilt window falls
// back to their defaults — confirm all callers use the default gather dim.
template <typename NewTensorView_,
          typename OldTensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename StaticPageIndexArray_,
          typename StaticValidArray_,
          index_t HsGatherDim = 0,
          index_t NumCoord = 1>
CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_& new_tensor_view,
                                               const tile_scatter_gather<OldTensorView_,
                                                                         WindowLengths_,
                                                                         StaticTileDistribution_,
                                                                         StaticPageIndexArray_,
                                                                         StaticValidArray_,
                                                                         HsGatherDim,
                                                                         NumCoord>& tile_window)
{
    // Carry the existing window's state over verbatim; only the view changes.
    const auto& lengths  = tile_window.window_lengths_;
    const auto& origin   = tile_window.window_origin_;
    const auto& dstr     = tile_window.tile_dstr_;
    const auto& page_idx = tile_window.page_idx_;
    const auto& valids   = tile_window.valids_;
    return make_tile_scatter_gather(new_tensor_view, lengths, origin, dstr, page_idx, valids);
}
} // namespace ck_tile

View File

@@ -1153,6 +1153,33 @@ CK_TILE_DEVICE void move_tile_window(
window.move(step);
}
// Builds a new distributed tile window over a different bottom tensor view,
// reusing the lengths, origin, and tile distribution of an existing window.
template <typename NewTensorView_,
          typename OldTensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord = 1>
CK_TILE_DEVICE auto
replace_bottom_tensor_view(const NewTensorView_& new_tensor_view,
                           const tile_window_with_static_distribution<OldTensorView_,
                                                                      WindowLengths_,
                                                                      StaticTileDistribution_,
                                                                      NumCoord>& tile_window)
{
    // Only the underlying view changes; the window geometry carries over.
    const auto& lengths = tile_window.get_window_lengths();
    const auto& origin  = tile_window.get_window_origin();
    const auto& dstr    = tile_window.get_tile_distribution();
    return make_tile_window(new_tensor_view, lengths, origin, dstr);
}
// Overload for windows without a tile distribution: reuse the lengths and
// origin of the existing window on top of the new bottom tensor view.
template <typename NewTensorView_, typename OldTensorView_, typename WindowLengths_>
CK_TILE_DEVICE auto replace_bottom_tensor_view(
    const NewTensorView_& new_tensor_view,
    const tile_window_with_static_lengths<OldTensorView_, WindowLengths_>& tile_window)
{
    const auto& lengths = tile_window.get_window_lengths();
    const auto& origin  = tile_window.get_window_origin();
    return make_tile_window(new_tensor_view, lengths, origin);
}
/**
* @brief Type trait to determine if a type is a tile window with static distribution.
*