From 70a8425dfb9576a5cb1c193288144386c7366682 Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Fri, 5 Dec 2025 09:35:27 -0700 Subject: [PATCH] Congma/ck tile/aquant mem pipeline (#3346) * [CK TILE GEMM QUANT] Fix the bug in HotLoopTail of memory pipeline [ROCm/composable_kernel commit: ed080f5a56c38caea8fedbd0bcc2919ba2376a6f] --- .../run_gemm_quant_example.inc | 11 +- .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 8 +- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 151 ++++++++++++++---- 3 files changed, 127 insertions(+), 43 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 396a54c7c2..0ee19b4a26 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -69,7 +69,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str using BaseGemmPipeline = std::conditional_t< GemmConfig::PreshuffleB == true, ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseGemmPipelineAgBgCrCompV3>; + std::conditional_t< + QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true, + ck_tile::BaseGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BaseGemmPipelineAgBgCrCompV3>>>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -128,7 +133,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::GemmPipelineAgBgCrCompV3, std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::AQuantGemmPipelineAgBgCrMem>, std::conditional_t, ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index 71e0ebb957..38a22e38ac 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -36,17 +36,13 @@ struct BaseGemmPipelineAgBgCrMem // TODO: Is this 32K value gfx9 arch specific? static constexpr index_t MinMemInFlyBytes = 32768; - static constexpr index_t WgpPerCU = - (4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1; + static constexpr index_t WgpPerCU = ck_tile::max(4 * get_warp_size() / BlockSize, 1); static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil(MinMemInFlyBytes / WgpPerCU, (MPerBlock * sizeof(ADataType) / APackedSize + NPerBlock * sizeof(BDataType) / BPackedSize) * KPerBlock); - static constexpr index_t PrefetchStages = - FullMemBandPrefetchStages >= 2 - ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 - : 2; + static constexpr index_t PrefetchStages = ck_tile::clamp(FullMemBandPrefetchStages, 2, 8); static constexpr index_t LocalPrefillStages = 1; static constexpr index_t GlobalBufferNum = PrefetchStages; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index f3c8b7a1a3..7f89d98349 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -80,6 +80,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr auto TailNum = Problem::TailNum; static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + using Base::PrefetchStages; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -165,6 +168,19 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { using Base = PipelineImplBase; + template + CK_TILE_DEVICE static void + LoadAndConvertATile(ABlockTile_& a_block_tile, + ADramWindow& a_dram_window, + const DramTileWindowStep& dram_tile_window_step) + { + using DestDataType = typename ABlockTile_::DataType; + using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; + constexpr index_t UnaryOpSize = 8; + load_int4_tile(a_block_tile, a_dram_window); + move_tile_window(a_dram_window, dram_tile_window_step); + } + template const BDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, const AQDramBlockWindowTmp& aq_dram_block_window_tmp, - index_t m, + [[maybe_unused]] index_t m, index_t num_loop, void* p_smem) const { - (void)m; // unused variable static_assert( std::is_same_v> && std::is_same_v std::is_same_v; constexpr bool is_b_row_major = std::is_same_v; - static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)"); static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!"); - static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}], - "Aq block window has incorrect lengths for defined AqLayout!"); static_assert(is_a_col_major ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && @@ -217,7 +228,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem "B block window has incorrect lengths for defined BLayout!"); // A/B tiles in LDS - using the same approach as regular gemm pipeline - auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem); + auto ab_lds_blocks = Base::template GetABLdsTensorViews(p_smem); auto& a_lds_block = ab_lds_blocks.at(I0{}); auto& b_lds_block = ab_lds_blocks.at(I1{}); @@ -249,7 +260,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); using ABlockTile = - decltype(make_static_distributed_tensor(ABlockTileDistr{})); + decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); using AQBlockTile = @@ -272,7 +283,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ); // Global prefetch initialization - DRAM to VGPRs - Base::GlobalPrefetch( + LoadAndConvertATile( a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); Base::GlobalPrefetch( b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); @@ -282,10 +293,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS prefill - VGPRs to LDS - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); } @@ -293,10 +304,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); } @@ -306,9 +317,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } // Additional prefetching for memory pipeline - DRAM to VGPRs static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); + LoadAndConvertATile(a_block_tiles.get(number{}), + a_copy_dram_window, + a_dram_tile_window_step); Base::GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window, b_dram_tile_window_step); @@ -325,16 +336,17 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, aq_block_tiles.get(number{}), a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); // Prepare next iteration data - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d( a_shuffle_tmp, @@ -348,7 +360,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -365,9 +377,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem b_element_func); } - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); + LoadAndConvertATile(a_block_tiles.get(number{}), + a_copy_dram_window, + a_dram_tile_window_step); Base::GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window, b_dram_tile_window_step); @@ -381,20 +393,89 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } // Tail handling - block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); - block_gemm( - c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window); + auto HotLoopTail = [&](auto tail_num) { + static_for<0, tail_num - 1, 1>{}([&](auto prefetch_idx) { + block_sync_lds(); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm(c_block_tile, + aq_block_tiles.get(number{}), + a_lds_gemm_window, + b_lds_gemm_window); + // no second block_sync_lds because it's interwave - if constexpr(TailNum == TailNumber::Even) - { + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, + a_block_tiles.get(number{})); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + Base::LocalPrefill(a_copy_lds_window, + a_block_tiles.get(number{})); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, + b_block_tiles.get(number{})); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + Base::LocalPrefill(b_copy_lds_window, + b_block_tiles.get(number{})); + } + }); - Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I1{}), a_element_func); - Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I1{}), b_element_func); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm(c_block_tile, + aq_block_tiles.get(number{}), + a_lds_gemm_window, + b_lds_gemm_window); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm( - c_block_tile, aq_block_tiles.get(I1{}), a_lds_gemm_window, b_lds_gemm_window); + c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window); + } + else if constexpr(TailNum == TailNumber::Two) + { + HotLoopTail(number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + HotLoopTail(number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + HotLoopTail(number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + HotLoopTail(number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + HotLoopTail(number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + HotLoopTail(number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + HotLoopTail(number{}); } return c_block_tile; } @@ -413,7 +494,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return PipelineImpl{} .template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const BDataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp,