From dc83e285e10ba0cdd2dfee9a1208746301d712ed Mon Sep 17 00:00:00 2001 From: Cong Ma Date: Thu, 22 Jan 2026 17:49:39 -0500 Subject: [PATCH] [CK TILE] simplify function GetKBPerLoad --- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 1ff95b157c..e33d525e28 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -39,21 +39,17 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad() { + using BDataType = remove_cvref_t; using TileShape = typename Problem::BlockGemmShape; -#if defined(__gfx11__) - constexpr index_t scale = 4; -#else - constexpr index_t scale = get_warp_size() == 32 ? 2 : 1; -#endif - if constexpr(TileShape::WarpTile::at(I1) == 32) - { - return TileShape::WarpTile::at(I2) * scale / 2; - } - else - { - static_assert(TileShape::WarpTile::at(I1) == 16); - return TileShape::WarpTile::at(I2) * scale / 4; - } + + constexpr index_t k_b_per_load = + TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size(); + + /* The k_b_per_load should meet the requirement that each thread loads 16 bytes in + * Preshuffle B */ + static_assert(k_b_per_load * sizeof(BDataType) == 16); + + return k_b_per_load; } template