diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 116661c157..2b2333b04c 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -211,6 +211,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; + static constexpr int kBlockPerCu = 2; }; template diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index 59a5b0df4e..d83338fbb2 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -69,7 +69,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV using Base::m_preload; - static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + static constexpr index_t VectorLoadSize = Problem::VectorLoadSize; static constexpr index_t KPerBlockBQ = integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK); static constexpr index_t QScalesPerBlockRow = @@ -95,6 +96,56 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV // clang-format on } + template + CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t Aload_inst = + (kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize; + constexpr index_t Bload_inst = + (kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize; + constexpr index_t BQload_inst = ck_tile::integer_divide_ceil( + ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType), + QuantGroupSize::kK * QuantGroupSize::kK), + VectorLoadSize); + constexpr index_t kLdsVec = 8; + constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst; + constexpr index_t ds_read_inst = kMPerBlock / kLdsVec; + constexpr index_t ds_write_inst = Aload_inst; + constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN); + constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); + constexpr index_t buffer_load_rep = + min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma + + static_for<0, nloop, 1>{}([&](auto j_inst) { + ignore = j_inst; + static_for<0, mfma_inst, 1>{}([&](auto i_inst) { + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA + + if constexpr(ds_rep > 0 && i_inst % ds_rep == 0) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::DS_READ, 1, 0); // DS read + } + if constexpr(ds_rep > 0 && i_inst % ds_rep == 1) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write + } + + if constexpr(buffer_load_rep > 0 && i_inst % buffer_load_rep == 0) + { + if constexpr(ds_write_inst > 0) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read + } + } + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + static constexpr bool PreshuffleB = Problem::PreshuffleB; static constexpr auto TailNum = Problem::TailNum; @@ -130,6 +181,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV static_assert(!is_b_row_major, "B must be col major (row major not supported yet)"); const index_t iMWarp = get_warp_id() / NWarp; + // Double-Buffering (loop_count=2) for full load/compute overlap. + const index_t loop_count = 2; __builtin_amdgcn_sched_barrier(0); @@ -313,9 +366,26 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV __builtin_amdgcn_sched_barrier(0); // MAIN LOOP - index_t iCounter = (num_loop - 1) / 2; + index_t iCounter = (num_loop - 1) / loop_count; + while(iCounter > 0) { + __builtin_amdgcn_sched_barrier(0); + // Prefill A(2i+1) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // Prefetch A(2i+2) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_ping, + bq_block_tile, + a_warp_windows_ping); // prefetch B(2i+1) static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { @@ -342,29 +412,12 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); } - // Prefill A(2i+1) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_pong, a_block_tile_tmp); - - // Prefetch A(2i+2) - a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // GEMM 2i - block_weight_preshuffle(c_block_tile, - a_warp_tensor, - b_warp_tensor_ping, - bq_block_tile, - a_warp_windows_ping); - static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; a_warp_tensor(loadIter) = load_tile(a_warp_windows_pong(number{})(number{})); }); - Base::HotLoopScheduler(); // Next K @@ -416,9 +469,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV a_warp_tensor(loadIter) = load_tile(a_warp_windows_ping(number{})(number{})); }); - Base::HotLoopScheduler(); - iCounter--; + HotLoopScheduler(); } // tail @@ -456,15 +508,13 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV load_tile(a_warp_windows_pong(number{})(number{})); }); - Base::Last2ndHotLoopScheduler(); - // GEMM loopK block_weight_preshuffle(c_block_tile, a_warp_tensor, b_warp_tensor_pong, bq_block_tile_2, a_warp_windows_pong); - Base::LastHotLoopScheduler(); + HotLoopScheduler(); } else if constexpr(TailNum == TailNumber::Odd) {