From 7afee6c5367b3ffb9d34aa491bc9cae0afa7ef22 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Tue, 10 Jun 2025 11:13:40 -0700 Subject: [PATCH] fix flatmm kernel for bigger size for fp16 datatype (#2302) [ROCm/composable_kernel commit: bd270fe4bcee5d2f8d9b011ca3e3fbcd1899900a] --- example/ck_tile/18_flatmm/CMakeLists.txt | 4 +- example/ck_tile/18_flatmm/flatmm_basic.cpp | 77 +++++++------------ example/ck_tile/18_flatmm/flatmm_basic.hpp | 37 +++++++++ .../ck_tile/18_flatmm/run_flatmm_example.inc | 65 ++++++---------- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 4 +- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 2 +- 6 files changed, 91 insertions(+), 98 deletions(-) diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index f4d823e91a..58e06f3c0f 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -3,6 +3,6 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) set(EXAMPLE_FLATMM_COMPILE_OPTIONS) # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) -list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -DENABLE_FP8=1 -Wno-unused-local-typedef) -#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16=1 -DENABLE_FP8=1 -Wno-unused-local-typedef) +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -Wno-unused-local-typedef) +#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16=1 -Wno-unused-local-typedef) target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 2dbff1bc5c..c564d7d1b1 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -22,49 +22,22 @@ template float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s) { - // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - - constexpr int kBlockPerCu = 2; - - // This part comes from the Codegen -#if defined(USING_MFMA_16x16x32) || defined(ENABLE_FP16) - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 128; - constexpr ck_tile::index_t K_Tile = 128; - - constexpr ck_tile::index_t M_Warp = 1; - constexpr ck_tile::index_t N_Warp = 4; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type::value ? 16 : 32; - constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type::value ? 16 : 32; - constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type::value ? 64 : 16; - -#elif defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 128; - - constexpr ck_tile::index_t M_Warp = 1; - constexpr ck_tile::index_t N_Warp = 8; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type::value ? 32 : 32; - constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type::value ? 32 : 32; - constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type::value ? 32 : 16; -#endif - using CodegenFlatmmShape = - ck_tile::TileFlatmmShape, - ck_tile::sequence, - ck_tile::sequence>; + using FlatmmConfig = FlatmmConfig; + using CodegenFlatmmShape = ck_tile::TileFlatmmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; using TilePartitioner = ck_tile::GemmTile1DPartitioner; - using CodegenGemmTraits = - ck_tile::TileGemmTraits; + using CodegenGemmTraits = ck_tile::TileGemmTraits; using CodegenPipelineProblem = ck_tile::GemmPipelineProblem>; @@ -110,8 +83,9 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() + << CodegenPipelineProblem::GetName() << " grid: {" << grids.x << ", " + << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } @@ -150,12 +124,15 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con ave_time = ck_tile::launch_kernel_preprocess( s, run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } return ave_time; }; diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index 55f2d4f367..6b52ce8b1b 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -109,6 +109,43 @@ struct is_8bit_type { }; +template +struct FlatmmConfig +{ +#if defined(USING_MFMA_16x16x32) + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type::value ? 16 : 32; + static constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type::value ? 16 : 32; + static constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type::value ? 64 : 16; + +#elif defined(USING_MFMA_32x32x16) + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 8; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type::value ? 32 : 32; + static constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type::value ? 32 : 32; + static constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type::value ? 32 : 16; +#endif + // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr int kBlockPerCu = 2; +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index 3d4f154af7..1607fb6163 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -32,38 +32,20 @@ static constexpr inline auto is_row_major(Layout layout_) } // mfma_type, 0:32x32, 1:16x16 -template -auto shuffle_b(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma_type) +template +auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - - if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) - { - ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 16, 2, 8}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } - else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1) - { - ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 32, 4, 8}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } - else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 0) - { - ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 32, 2, 16}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } - else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 1) - { - ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 64, 4, 16}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } - return t; + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4; + ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, + FlatmmConfig::N_Warp_Tile, + k_ / FlatmmConfig::K_Warp_Tile, + divisor, + FlatmmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } template @@ -149,10 +131,11 @@ int run_flatmm_example_with_layouts(int argc, if(!result) return -1; - using ADataType = typename GemmBasicTypeConfig::ADataType; - using BDataType = typename GemmBasicTypeConfig::BDataType; - using CDataType = typename GemmBasicTypeConfig::CDataType; - using AccDataType = typename GemmBasicTypeConfig::AccDataType; + using ADataType = typename GemmBasicTypeConfig::ADataType; + using BDataType = typename GemmBasicTypeConfig::BDataType; + using CDataType = typename GemmBasicTypeConfig::CDataType; + using AccDataType = typename GemmBasicTypeConfig::AccDataType; + using FlatmmConfig = FlatmmConfig; ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); @@ -163,8 +146,9 @@ int run_flatmm_example_with_layouts(int argc, ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); ck_tile::index_t kbatch = arg_parser.get_int("split_k"); - int n_warmup = arg_parser.get_int("warmup"); - int n_repeat = arg_parser.get_int("repeat"); + + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); @@ -188,13 +172,8 @@ int run_flatmm_example_with_layouts(int argc, c_rslt_host.SetZero(); // do pre-shuffle - std::string mfma = arg_parser.get_str("prec"); -#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) - ck_tile::index_t mfma_type = 1; -#else - ck_tile::index_t mfma_type = 0; -#endif - ck_tile::HostTensor b_shuffle_host = shuffle_b(b_origin_host, mfma, mfma_type); + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_origin_host); + ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes()); b_shuffle_dev_buf.ToDevice(b_shuffle_host.data()); diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index cbd20a6ea3..aa4d233ecb 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -75,7 +75,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() { -#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) || defined(USING_MFMA_32x32x16) +#if defined(USING_MFMA_16x16x32) || defined(USING_MFMA_32x32x16) constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -92,7 +92,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp; constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp; #endif -#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) +#if defined(USING_MFMA_16x16x32) static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { ignore = i; __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 7d06d871a9..91323d2c39 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -19,7 +19,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { using namespace ck_tile; -#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) +#if defined(USING_MFMA_16x16x32) /*reduce transform layers,compare with old ck*/ constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;