From 080fa14140700d83656558203283d155be0e3d2d Mon Sep 17 00:00:00 2001 From: Cong Ma Date: Thu, 22 Jan 2026 12:36:40 -0500 Subject: [PATCH] [CK TILE] Add new function get_k_warp_tile_for_preshuffle_b --- example/ck_tile/03_gemm/gemm_utils.hpp | 4 ++-- .../03_gemm/gemm_weight_preshuffle.cpp | 3 +++ .../ops/gemm/pipeline/tile_gemm_shape.hpp | 20 +++++++++++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index c1df27ecc8..c1a37c8577 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -255,7 +255,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; @@ -280,7 +280,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index 85f8c346c9..d4c55de9e7 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -94,7 +94,10 @@ int main(int argc, char* argv[]) auto result = arg_parser.parse(argc, argv); if(!result) + { + arg_parser.print(); return -1; + } try { diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index 525a4ef9fc..429522ac68 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -66,4 +66,24 @@ constexpr index_t get_k_warp_tile() #endif } +template +constexpr index_t get_k_warp_tile_for_preshuffle_b() +{ +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(N_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + // K value is determined by the maximum bytes that can be loaded in a single instruction + // This K value is sufficient for MFMA/WMMA shapes: 16x16x16, 16x16x32, 32x32x16 + const int kMaxBytesPerLoad = 16; // buffer load max 16 bytes + const int kMaxElementsPerLoad = kMaxBytesPerLoad / sizeof(PrecType); + const int KLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile; + return kMaxElementsPerLoad * KLanePerWarp; +#endif +} + } // namespace ck_tile