Ck tile/gemm blockscale opt (#3227)

* GEMM block scale optimization kernel

* GEMM block scale optimization kernel

* Fix: Apply clang-format for style consistency

* Fix: Apply clang-format for style consistency

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
kensclin
2025-12-04 14:07:23 +08:00
committed by GitHub
parent eb7f617713
commit ffc3120f63
2 changed files with 75 additions and 24 deletions

View File

@@ -69,7 +69,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
using Base::m_preload;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
static constexpr index_t KPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
static constexpr index_t QScalesPerBlockRow =
@@ -95,6 +96,56 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
// clang-format on
}
template <index_t nloop>
// Emits AMDGPU scheduler hints (sched_group_barrier intrinsics) that interleave
// MFMA, LDS read/write, VMEM load, and VALU instructions across the hot loop
// body, repeated `nloop` times (one repetition per unrolled loop iteration).
// The instruction counts below are compile-time estimates derived from the
// block tile shape; no runtime work is done here.
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
{
// Estimated per-thread count of buffer-load instructions for the A tile:
// total A-tile bytes spread over BlockSize threads at VectorLoadSize bytes each.
constexpr index_t Aload_inst =
(kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize;
// Same estimate for the B tile.
constexpr index_t Bload_inst =
(kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize;
// Estimated loads for the B quant-scale tile (one scale per quant group).
// NOTE(review): divisor is QuantGroupSize::kK * QuantGroupSize::kK — confirm
// this shouldn't be QuantGroupSize::kN * QuantGroupSize::kK for an N x K tile.
constexpr index_t BQload_inst = ck_tile::integer_divide_ceil(
ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType),
QuantGroupSize::kK * QuantGroupSize::kK),
VectorLoadSize);
// LDS vector width (elements per ds_read) assumed by the ds_read estimate below.
constexpr index_t kLdsVec = 8;
constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;
constexpr index_t ds_read_inst = kMPerBlock / kLdsVec;
// Every A buffer-load is mirrored by one LDS write (A is staged through LDS).
constexpr index_t ds_write_inst = Aload_inst;
// Number of MFMA instructions per loop body: one per warp-gemm tile.
constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
// How many MFMAs to issue between successive DS operations.
// NOTE(review): compile-time division — assumes ds_read_inst + ds_write_inst > 0.
constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
constexpr index_t buffer_load_rep =
min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma
// Repeat the interleaving pattern once per unrolled iteration.
static_for<0, nloop, 1>{}([&](auto j_inst) {
ignore = j_inst;
static_for<0, mfma_inst, 1>{}([&](auto i_inst) {
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA
// Stagger DS reads (phase 0) and DS writes (phase 1) so they do not
// compete for LDS bandwidth in the same slot.
if constexpr(ds_rep > 0 && i_inst % ds_rep == 0)
{
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::DS_READ, 1, 0); // DS read
}
if constexpr(ds_rep > 0 && i_inst % ds_rep == 1)
{
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write
}
// Slot one VMEM read per buffer_load_rep MFMAs (only when LDS writes
// exist, i.e. when data is actually being staged).
if constexpr(buffer_load_rep > 0 && i_inst % buffer_load_rep == 0)
{
if constexpr(ds_write_inst > 0)
{
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read
}
}
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU
});
});
// Full scheduling fence: nothing may be reordered across the end of the
// scheduled region.
__builtin_amdgcn_sched_barrier(0);
}
static constexpr bool PreshuffleB = Problem::PreshuffleB;
static constexpr auto TailNum = Problem::TailNum;
@@ -130,6 +181,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
static_assert(!is_b_row_major, "B must be col major (row major not supported yet)");
const index_t iMWarp = get_warp_id() / NWarp;
// Double-Buffering (loop_count=2) for full load/compute overlap.
const index_t loop_count = 2;
__builtin_amdgcn_sched_barrier(0);
@@ -313,9 +366,26 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
__builtin_amdgcn_sched_barrier(0);
// MAIN LOOP
index_t iCounter = (num_loop - 1) / 2;
index_t iCounter = (num_loop - 1) / loop_count;
while(iCounter > 0)
{
__builtin_amdgcn_sched_barrier(0);
// Prefill A(2i+1)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
// Prefetch A(2i+2)
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i
block_weight_preshuffle(c_block_tile,
a_warp_tensor,
b_warp_tensor_ping,
bq_block_tile,
a_warp_windows_ping);
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
@@ -342,29 +412,12 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
}
// Prefill A(2i+1)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
// Prefetch A(2i+2)
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i
block_weight_preshuffle(c_block_tile,
a_warp_tensor,
b_warp_tensor_ping,
bq_block_tile,
a_warp_windows_ping);
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
Base::HotLoopScheduler();
// Next K
@@ -416,9 +469,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
Base::HotLoopScheduler();
iCounter--;
HotLoopScheduler<loop_count>();
}
// tail
@@ -456,15 +508,13 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
Base::Last2ndHotLoopScheduler();
// GEMM loopK
block_weight_preshuffle(c_block_tile,
a_warp_tensor,
b_warp_tensor_pong,
bq_block_tile_2,
a_warp_windows_pong);
Base::LastHotLoopScheduler();
HotLoopScheduler<loop_count>();
}
else if constexpr(TailNum == TailNumber::Odd)
{