diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 1ff95b157c..e33d525e28 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -39,21 +39,17 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad() { + using BDataType = remove_cvref_t; using TileShape = typename Problem::BlockGemmShape; -#if defined(__gfx11__) - constexpr index_t scale = 4; -#else - constexpr index_t scale = get_warp_size() == 32 ? 2 : 1; -#endif - if constexpr(TileShape::WarpTile::at(I1) == 32) - { - return TileShape::WarpTile::at(I2) * scale / 2; - } - else - { - static_assert(TileShape::WarpTile::at(I1) == 16); - return TileShape::WarpTile::at(I2) * scale / 4; - } + + constexpr index_t k_b_per_load = + TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size(); + + /* The k_b_per_load should meet the requirement that each thread loads 16 bytes in + * Preshuffle B */ + static_assert(k_b_per_load * sizeof(BDataType) == 16); + + return k_b_per_load; } template