mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
Ck tile/gemm blockscale opt (#3227)
* GEMM block scale optimization kernel
* GEMM block scale optimization kernel
* Fix: Apply clang-format for style consistency
* Fix: Apply clang-format for style consistency

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -69,7 +69,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
|
||||
using Base::m_preload;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
|
||||
static constexpr index_t KPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
@@ -95,6 +96,56 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
template <index_t nloop>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr index_t Aload_inst =
|
||||
(kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize;
|
||||
constexpr index_t Bload_inst =
|
||||
(kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize;
|
||||
constexpr index_t BQload_inst = ck_tile::integer_divide_ceil(
|
||||
ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType),
|
||||
QuantGroupSize::kK * QuantGroupSize::kK),
|
||||
VectorLoadSize);
|
||||
constexpr index_t kLdsVec = 8;
|
||||
constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;
|
||||
constexpr index_t ds_read_inst = kMPerBlock / kLdsVec;
|
||||
constexpr index_t ds_write_inst = Aload_inst;
|
||||
constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
|
||||
constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
|
||||
constexpr index_t buffer_load_rep =
|
||||
min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma
|
||||
|
||||
static_for<0, nloop, 1>{}([&](auto j_inst) {
|
||||
ignore = j_inst;
|
||||
static_for<0, mfma_inst, 1>{}([&](auto i_inst) {
|
||||
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA
|
||||
|
||||
if constexpr(ds_rep > 0 && i_inst % ds_rep == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::DS_READ, 1, 0); // DS read
|
||||
}
|
||||
if constexpr(ds_rep > 0 && i_inst % ds_rep == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write
|
||||
}
|
||||
|
||||
if constexpr(buffer_load_rep > 0 && i_inst % buffer_load_rep == 0)
|
||||
{
|
||||
if constexpr(ds_write_inst > 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read
|
||||
}
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
static constexpr bool PreshuffleB = Problem::PreshuffleB;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
|
||||
@@ -130,6 +181,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
static_assert(!is_b_row_major, "B must be col major (row major not supported yet)");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
// Double-Buffering (loop_count=2) for full load/compute overlap.
|
||||
const index_t loop_count = 2;
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -313,9 +366,26 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// MAIN LOOP
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
index_t iCounter = (num_loop - 1) / loop_count;
|
||||
|
||||
while(iCounter > 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// Prefill A(2i+1)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
|
||||
// Prefetch A(2i+2)
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i
|
||||
block_weight_preshuffle(c_block_tile,
|
||||
a_warp_tensor,
|
||||
b_warp_tensor_ping,
|
||||
bq_block_tile,
|
||||
a_warp_windows_ping);
|
||||
// prefetch B(2i+1)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
@@ -342,29 +412,12 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
|
||||
}
|
||||
|
||||
// Prefill A(2i+1)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
|
||||
// Prefetch A(2i+2)
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i
|
||||
block_weight_preshuffle(c_block_tile,
|
||||
a_warp_tensor,
|
||||
b_warp_tensor_ping,
|
||||
bq_block_tile,
|
||||
a_warp_windows_ping);
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
Base::HotLoopScheduler();
|
||||
|
||||
// Next K
|
||||
|
||||
@@ -416,9 +469,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
Base::HotLoopScheduler();
|
||||
|
||||
iCounter--;
|
||||
HotLoopScheduler<loop_count>();
|
||||
}
|
||||
|
||||
// tail
|
||||
@@ -456,15 +508,13 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
|
||||
Base::Last2ndHotLoopScheduler();
|
||||
|
||||
// GEMM loopK
|
||||
block_weight_preshuffle(c_block_tile,
|
||||
a_warp_tensor,
|
||||
b_warp_tensor_pong,
|
||||
bq_block_tile_2,
|
||||
a_warp_windows_pong);
|
||||
Base::LastHotLoopScheduler();
|
||||
HotLoopScheduler<loop_count>();
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user