[CK_TILE] Refine fp8 support in flatmm (#2239)

* [CK_TILE] Refine fp8 in flatmm

1. Replace USING_MFMA_16x16x32 & USING_MFMA_32x32x16 with constexpr
2. Add an additional const check to avoid build error in HotLoopScheduler
3. Refine shuffleb to support both tile 32x32 and 16x16
4. Support command option -init
5. Move Gemm warp definition to a separate struct

* fix clang format

* fix clang format

* keep default behavior unchanged (warp tile = 16x16)

* fix tile engine build error

* fix a typo in codegen_utils.py

* address review comments

* address review comments

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>

[ROCm/composable_kernel commit: 37e1a27537]
This commit is contained in:
linqunAMD
2025-06-25 16:07:45 +08:00
committed by GitHub
parent b62e551ccb
commit d2ec53a74e
10 changed files with 313 additions and 198 deletions

View File

@@ -44,9 +44,12 @@ CSHUFFLE_EPILOGUE = """
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,

View File

@@ -193,7 +193,7 @@ struct GemmKernel {{
static constexpr bool kPadN = {pad_n};
static constexpr bool kPadK = {pad_k};
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
static float launch(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{
static constexpr bool permuteA = false;
static constexpr bool permuteB = false;
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
@@ -306,7 +306,7 @@ struct GemmKernel {{
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
}};
ave_time = ck_tile::launch_kernel_preprocess(
stream,
@@ -570,12 +570,13 @@ struct GemmDispatcher {
// Use a static local variable
static std::unordered_map<
std::string,
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>>
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>>
kernel_map;
return kernel_map;
}
static void init(bool structured_sparsity) {
ck_tile::ignore = structured_sparsity;
auto& kernel_map = get_kernel_map();
if(!kernel_map.empty()) return;
\n"""
@@ -586,7 +587,7 @@ struct GemmDispatcher {
for j in range(len(tile)):
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile[
j]
content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """
content += f"""[=](ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ """
content += f"""
if(structured_sparsity){{ // SMFMA"""
sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
@@ -615,7 +616,7 @@ struct GemmDispatcher {
content += """ }
template <typename Kernel>
static std::tuple<std::string, float> run_kernel(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream)
static std::tuple<std::string, float> run_kernel(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream)
{
std::string name = Kernel::get_name();
float avg_time = Kernel::launch(args, stream);

View File

@@ -22,7 +22,7 @@ class GemmProfiler
void benchmark(GemmProblem& gemm_problem,
std::vector<std::function<std::tuple<std::string, float>(
ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables)
ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
{
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
@@ -89,7 +89,7 @@ class GemmProfiler
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
ck_tile::GemmHostArgs gemm_args;
ck_tile::GemmHostArgs<> gemm_args;
gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();