diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 462f11e405..905d3ffc72 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -148,7 +148,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool kPadK = true; @@ -174,7 +174,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; @@ -220,7 +220,8 @@ struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool kPadK = true; diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_config.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_config.hpp index a1f287df6b..2ea28ec558 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_config.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_config.hpp @@ -84,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 37fc998e5b..37e46c6b04 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -145,7 +145,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; @@ -175,7 +175,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp index 7cd9889d78..7f16a4bde0 100644 --- a/include/ck_tile/host/tensor_shuffle_utils.hpp +++ b/include/ck_tile/host/tensor_shuffle_utils.hpp @@ -77,37 +77,35 @@ auto shuffle_b(const ck_tile::HostTensor& t, const GemmConfig& gemmConfig) if(ck_tile::is_gfx12_supported()) { - constexpr int divisor = 2; - constexpr int kABK1PerLane = 8; - int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane; + constexpr int kKLanePerWarp = 2; + constexpr int kABK1PerLane = 8; + int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane; ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, gemmConfig.N_Warp_Tile, k_ / gemmConfig.K_Warp_Tile, kABK0PerLane, - divisor, + kKLanePerWarp, kABK1PerLane}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); } else { - int divisor = 1; + int kKLanePerWarp = 1; if(ck_tile::is_gfx11_supported()) { - divisor = 1; + kKLanePerWarp = 1; } else { - assert(is_wave32() == false); - divisor = get_warp_size() / gemmConfig.N_Warp_Tile; + kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile; } ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, gemmConfig.N_Warp_Tile, - k_ / gemmConfig.K_Warp_Tile, - divisor, - gemmConfig.K_Warp_Tile / divisor}); + k_ / (gemmConfig.K_Warp_Tile / kKLanePerWarp), + gemmConfig.K_Warp_Tile / kKLanePerWarp}); std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + return ck_tile::reference_permute(t_view, {0, 2, 1, 3}); } } @@ -144,39 +142,39 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor& t, const GemmConfig& gemmC int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp; if(ck_tile::is_gfx12_supported()) { - constexpr int divisor = 2; - constexpr int kABK1PerLane = 8; - int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane; + constexpr int kKLanePerWarp = 2; + constexpr int kABK1PerLane = 8; + int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane; ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, gemmConfig.N_Warp, gemmConfig.N_Warp_Tile, NRepeat, k_ / gemmConfig.K_Warp_Tile, kABK0PerLane, - divisor, + kKLanePerWarp, kABK1PerLane}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7}); } else { - int divisor = 1; + int kKLanePerWarp = 1; if(ck_tile::is_gfx11_supported()) { - divisor = 1; + kKLanePerWarp = 1; } else { assert(is_wave32() == false); - divisor = get_warp_size() / gemmConfig.N_Warp_Tile; + kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile; } ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, gemmConfig.N_Warp, gemmConfig.N_Warp_Tile, NRepeat, k_ / gemmConfig.K_Warp_Tile, - divisor, - gemmConfig.K_Warp_Tile / divisor}); + kKLanePerWarp, + gemmConfig.K_Warp_Tile / kKLanePerWarp}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); } diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index b9382dee84..63bec56e20 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -69,26 +69,38 @@ constexpr index_t get_k_warp_tile() template constexpr index_t get_k_warp_tile_for_preshuffle_b() { +#if CK_TILE_USE_WMMA + return 16; +#else + // When preshuffle B is enabled, the K_Warp_Tile must be sized appropriately + // to support both dwordx4 loading instructions and MFMA instruction requirements. + // A single dwordx4 load may feed one or more MFMA instructions, or conversely, + // multiple loads may be required for a single MFMA instruction with a larger K dimension + // (e.g., 16x16x128 on gfx950). + + // To achieve optimal memory bandwidth, each thread loads a minimum of 16 bytes (dwordx4) + // from global memory. const int kMaxBytesPerLoad = 16; // buffer load max 16 bytes const int kMaxElementsPerLoad = kMaxBytesPerLoad / sizeof(PrecType); - const int kKLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile; - const int kKPerWarp = kMaxElementsPerLoad * kKLanePerWarp; + const int kKLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile; + const int kKPerWarp = kMaxElementsPerLoad * kKLanePerWarp; + // Minimum K_Warp_Tile required by MFMA instructions const index_t kMfmaN16Index = 0; const index_t kMfmaN32Index = 1; #if defined(CK_GFX950_SUPPORT) - const index_t kF8MfmaMaxK[2] = {128, 64}; + const index_t kF8MfmaMaxK[2] = {128, 64}; const index_t kF16MfmaMaxK[2] = {32, 16}; #else - const index_t kF8MfmaMaxK[2] = {32, 16}; + const index_t kF8MfmaMaxK[2] = {32, 16}; const index_t kF16MfmaMaxK[2] = {16, 8}; #endif - const bool kIsF8 = - std::is_same_v || std::is_same_v; - const index_t kMfmaIndex = N_Warp_Tile == 16 ? kMfmaN16Index : kMfmaN32Index; - const index_t kMfmaMaxK = kIsF8 ? kF8MfmaMaxK[kMfmaIndex] : kF16MfmaMaxK[kMfmaIndex]; + const bool kIsF8 = std::is_same_v || std::is_same_v; + const index_t kMfmaIndex = N_Warp_Tile == 16 ? kMfmaN16Index : kMfmaN32Index; + const index_t kMfmaMaxK = kIsF8 ? kF8MfmaMaxK[kMfmaIndex] : kF16MfmaMaxK[kMfmaIndex]; return max(kKPerWarp, kMfmaMaxK); +#endif } } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index e33d525e28..d1e498361a 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -39,16 +39,11 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad() { - using BDataType = remove_cvref_t; using TileShape = typename Problem::BlockGemmShape; constexpr index_t k_b_per_load = TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size(); - /* The k_b_per_load should meet the requirement that each thread loads 16 bytes in - * Preshuffle B */ - static_assert(k_b_per_load * sizeof(BDataType) == 16); - return k_b_per_load; } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 8c9955da74..4dbc122110 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -16,6 +16,7 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm_quant.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" // Forward declarations for quant type-specific implementations template @@ -74,11 +75,14 @@ class TestCkTileGemmQuantBase : public ::testing::Test static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile; static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile; - static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile; - static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant; - static constexpr bool PreshuffleB = GemmConfig::PreshuffleB; - static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN; - static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer; + static constexpr ck_tile::index_t K_Warp_Tile = + GemmConfig::PreshuffleB + ? ck_tile::get_k_warp_tile_for_preshuffle_b() + : ck_tile::get_k_warp_tile(); + static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant; + static constexpr bool PreshuffleB = GemmConfig::PreshuffleB; + static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN; + static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer; static constexpr bool kPadM = GemmConfig::kPadM; static constexpr bool kPadN = GemmConfig::kPadN; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 79c86935ef..bee2e7ed71 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -6,16 +6,7 @@ #include "test_gemm_quant_base.hpp" #include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/tensor_shuffle_utils.hpp" - -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if CK_TILE_USE_WMMA - return 16; -#else - return is_8bit ? 64 : 32; -#endif -} +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" struct GemmConfigBase { @@ -50,23 +41,21 @@ struct GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + // K_Warp_Tile is derived from N_Warp_Tile and BDataType }; struct GemmConfigDecode : public GemmConfigBase { - static constexpr ck_tile::index_t M_Tile = 16; - static constexpr ck_tile::index_t N_Tile = 64; - static constexpr ck_tile::index_t K_Tile = 256; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; }; struct GemmConfigPrefill : public GemmConfigBase { - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; }; struct GemmConfigMxFp4 : public GemmConfigBase diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp index f6620c105d..77ed9f9bb6 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp @@ -11,6 +11,7 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" using AddScale = ck_tile::element_wise::AddScale; using ElementWiseAddAdd = ck_tile::element_wise::MultiDAdd; @@ -23,28 +24,6 @@ static constexpr inline auto is_row_major(Layout layout_) ck_tile::tensor_layout::gemm::RowMajor>>{}; } -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if CK_TILE_USE_WMMA - return 16; -#else -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -#endif -} - template constexpr ck_tile::index_t get_k_warp_tile() @@ -86,7 +87,7 @@ struct config static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; template @@ -102,7 +103,7 @@ struct config_wmma static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; template diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index e588ad2cc1..a490cf42f1 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -48,7 +48,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test static const ck_tile::index_t M_Warp_Tile = 16; static const ck_tile::index_t N_Warp_Tile = 16; static const ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem static constexpr bool TransposeC = false; // transpose c is not supported