From 42d775e4888ef68b770e66ead583cd1275c5364d Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Fri, 15 Aug 2025 15:49:07 -0700 Subject: [PATCH] Preshuffle Decode Prefill config fix (#2693) * feat(gemm_wp): add two new configs for wp * delete the unnecessary files * fix the config error * update the config --------- Co-authored-by: AviralGoelAMD [ROCm/composable_kernel commit: 5ada85ec047591dc2d67b3e608c1951156b5ef4f] --- example/ck_tile/03_gemm/gemm_utils.hpp | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) mode change 100644 => 100755 example/ck_tile/03_gemm/gemm_utils.hpp diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp old mode 100644 new mode 100755 index ab481b97a0..e319e2d668 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -35,6 +35,22 @@ constexpr ck_tile::index_t get_k_warp_tile() #endif } +template +constexpr ck_tile::index_t get_k_warp_tile_flatmm() +{ +#if defined(CK_GFX950_SUPPORT) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 32 : 64; +#endif +} + struct GemmConfigBase { static constexpr bool kPadM = false; @@ -229,7 +245,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; @@ -251,7 +267,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;