mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
GEMM Blockscale ABQuant Optimization (#3620)
* GEMM Blockscale ABQuant Optimization * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix precommit error * clean * Fix --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Ding, Yi <yi.ding@amd.com>
This commit is contained in:
@@ -101,10 +101,14 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
/**
|
||||
* @tparam nloop The number of iterations in the hot loop,
|
||||
* used to normalize scheduling costs.
|
||||
*/
|
||||
template <index_t nloop>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
static_assert(nloop > 0, "nloop must be greater than 0");
|
||||
// Estimated number of VMEM vector loads for A per block:
|
||||
// total A bytes / (threads per block * vector width)
|
||||
constexpr index_t Aload_inst =
|
||||
@@ -127,12 +131,13 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
// Total VMEM load instructions (A + B + quant data)
|
||||
constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;
|
||||
// Approximate number of LDS reads per block
|
||||
constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle;
|
||||
constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle / nloop;
|
||||
// Approximate number of LDS writes per block
|
||||
// (e.g., writing A from VMEM into LDS once per A load)
|
||||
constexpr index_t ds_write_inst = Aload_inst;
|
||||
// Number of MFMA instructions per wave for one block tile:
|
||||
constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
|
||||
constexpr index_t mfma_inst =
|
||||
((kMPerBlock / WG::kM) / nloop) * ((kNPerBlock / WG::kN) / nloop);
|
||||
// How often (in MFMA units) we should insert DS (LDS) operations.
|
||||
constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
|
||||
// How often (in MFMA units) we should insert VMEM buffer loads.
|
||||
@@ -169,7 +174,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
}
|
||||
// Always mark some VALU work in the loop to reflect auxiliary scalar
|
||||
// or vector ALU instructions that coexist with MFMA (Blockscale calculation).
|
||||
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 4, 0); // VALU
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -380,7 +385,6 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
|
||||
// Prefetch A1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// initialize C
|
||||
@@ -407,7 +411,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
while(iCounter > 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// Prefill A(2i+1)
|
||||
// Prefill A(2i+1) ds_write
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
|
||||
@@ -435,10 +439,14 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
});
|
||||
});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// prefetch Q(2i+1)
|
||||
aq_block_tile_2 = load_tile(aq_copy_dram_window);
|
||||
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
|
||||
bq_block_tile_2 = load_tile(bq_copy_dram_window);
|
||||
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
|
||||
|
||||
// Preload A(2i+1) ds_read
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
@@ -460,6 +468,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
});
|
||||
});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// prefetch Q(2i+1)
|
||||
aq_block_tile = load_tile(aq_copy_dram_window);
|
||||
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
@@ -481,7 +491,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
aq_block_tile_2,
|
||||
bq_block_tile_2,
|
||||
a_warp_windows_pong);
|
||||
|
||||
// Preload A(2i+2) ds_read
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
@@ -521,7 +531,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
aq_block_tile,
|
||||
bq_block_tile,
|
||||
a_warp_windows_ping);
|
||||
|
||||
// Preload A ds_read
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
|
||||
Reference in New Issue
Block a user