From d2ec53a74eaeedc672deeeb239177aea23a616c8 Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Wed, 25 Jun 2025 16:07:45 +0800 Subject: [PATCH] [CK_TILE] Refine fp8 support in flatmm (#2239) * [CK_TILE] Refine fp8 in flatmm 1. Replace USING_MFMA_16x16x32 & USING_MFMA_16x16x32 with constexpr 2. Add an additional const check to avoid build error in HotLoopScheduler 3. Refine shuffleb to support both tile 32x32 and 16x16 4. Support command option -init 5. Move Gemm warp defintion to a separate struct * fix clang format * fix clang format * keep default bhavior unchanged (warp tile = 16x16) * fix tile engine build error * fix a typo in codegen_utils.py * address review comments * address review comments --------- Co-authored-by: Thomas Ning [ROCm/composable_kernel commit: 37e1a2753702f003b751425502e037f2384aaa5f] --- example/ck_tile/18_flatmm/CMakeLists.txt | 2 - example/ck_tile/18_flatmm/flatmm_basic.cpp | 44 +++++-- example/ck_tile/18_flatmm/flatmm_basic.hpp | 109 +++++++++------- .../ck_tile/18_flatmm/run_flatmm_example.inc | 91 +++++++++----- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 119 +++++++++--------- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 115 ++++++++++------- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 13 +- tile_engine/ops/gemm/codegen_utils.py | 3 + tile_engine/ops/gemm/gemm_instance_builder.py | 11 +- tile_engine/ops/gemm/gemm_profiler.hpp | 4 +- 10 files changed, 313 insertions(+), 198 deletions(-) diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 58e06f3c0f..6d6b71ea18 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -3,6 +3,4 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) set(EXAMPLE_FLATMM_COMPILE_OPTIONS) # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) -list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -Wno-unused-local-typedef) -#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16=1 -Wno-unused-local-typedef) target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 8782d2bb6a..f96f558101 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -17,12 +17,12 @@ template float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s) { - using FlatmmConfig = FlatmmConfig; using CodegenFlatmmShape = ck_tile::TileFlatmmShape< ck_tile::sequence, ck_tile::sequence, @@ -32,18 +32,20 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con using TilePartitioner = ck_tile::GemmTile1DPartitioner; - using CodegenGemmTraits = ck_tile::TileGemmTraits; + using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; - const auto Run = [&](const auto memory_operation_) { + + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; using GemmEpilogue = ck_tile::CShuffleEpilogue< @@ -151,6 +153,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con } } +template