GEMM Blockscale ABQuant Optimization (#3620)

* GEMM Blockscale ABQuant Optimization

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix precommit error

* clean

* Fix

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Ding, Yi <yi.ding@amd.com>
This commit is contained in:
kensclin
2026-01-23 01:39:38 +08:00
committed by GitHub
parent 9e049a32a1
commit 31a35ecab4
7 changed files with 161 additions and 51 deletions

View File

@@ -213,6 +213,22 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
});
});
};
auto q_block_tensor = aq_block_tensor;
constexpr bool SimpleDequant =
Traits::NQPerBlock == 1 &&
AccTensor::get_distributed_spans()[I0].impl_.size() == 0; // c_transpose
if constexpr(SimpleDequant)
{
constexpr auto aq_spans = AQBlockTensor::get_distributed_spans();
sweep_tile_span(aq_spans[I0], [&](auto im) {
sweep_tile_span(aq_spans[I1], [&](auto ik) {
q_block_tensor(make_tuple(im, ik)) *=
bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik));
});
});
}
// hot loop:
static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
zero_accumulators();
static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
@@ -243,9 +259,29 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
}
});
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(aq_block_tensor);
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for_product<number<MIterPerWarp>, number<NIterPerWarp>>{}([&](auto mIter,
auto nIter) {
if constexpr(SimpleDequant)
{
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
constexpr auto block_idx_m = tile_distributed_index<mIter>{};
constexpr auto block_idx_kq = tile_distributed_index<kQScale>{};
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref += acc_val * q_block_tensor(make_tuple(block_idx_m, block_idx_kq));
});
}
else
{
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
aq_block_tensor);
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
@@ -273,7 +309,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f;
});
});
}
});
});
}

View File

@@ -285,37 +285,66 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
"C block tensor data type!");
constexpr auto warp_size = get_warp_size();
// Start from AQ block tensor and then scale it using BQ; this represents
// the combined A/B quantization scales for the block.
auto q_block_tensor = aq_block_tensor;
constexpr bool SimpleDequant =
Traits::NQPerBlock == 1 &&
CWarpTensor::get_distributed_spans()[I0{}].impl_.size() == 0; // c_transpose
if constexpr(SimpleDequant)
{
constexpr auto aq_spans = AQBlockTensor::get_distributed_spans();
sweep_tile_span(aq_spans[I0{}], [&](auto im) {
sweep_tile_span(aq_spans[I1{}], [&](auto ik) {
q_block_tensor(make_tuple(im, ik)) *=
bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik));
});
});
}
// hot loop:
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
static_for_product<number<MIterPerWarp>, number<NIterPerWarp>>{}([&](auto mIter,
auto nIter) {
CWarpTensor c_warp_tensor;
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() =
a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
if constexpr(kIterInQScale == 0)
{
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
});
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
if constexpr(kIterInQScale == 0)
{
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
if constexpr(SimpleDequant)
{
constexpr auto cw_spans = CWarpTensor::get_distributed_spans();
sweep_tile_span(cw_spans[I1{}], [&](auto in) {
constexpr auto block_idx_m = tile_distributed_index<mIter>{};
constexpr auto block_idx_n = detail::make_tile_distributed_index(
merge_sequences(sequence<nIter>{}, in.impl_));
constexpr auto block_idx_kq = tile_distributed_index<kQScale>{};
constexpr auto empty_idx = tile_distributed_index<>{};
c_block_tensor(make_tuple(block_idx_m, block_idx_n)) +=
c_warp_tensor(make_tuple(empty_idx, in)) *
q_block_tensor(make_tuple(block_idx_m, block_idx_kq));
});
}
else
{
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
@@ -387,7 +416,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
b_scale_reg_f);
});
}
});
}
});
});
}

View File

@@ -101,10 +101,14 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName());
// clang-format on
}
/**
* @tparam nloop The number of iterations in the hot loop,
* used to normalize scheduling costs.
*/
template <index_t nloop>
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
{
static_assert(nloop > 0, "nloop must be greater than 0");
// Estimated number of VMEM vector loads for A per block:
// total A bytes / (threads per block * vector width)
constexpr index_t Aload_inst =
@@ -127,12 +131,13 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
// Total VMEM load instructions (A + B + quant data)
constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;
// Approximate number of LDS reads per block
constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle;
constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle / nloop;
// Approximate number of LDS writes per block
// (e.g., writing A from VMEM into LDS once per A load)
constexpr index_t ds_write_inst = Aload_inst;
// Number of MFMA instructions per wave for one block tile:
constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
constexpr index_t mfma_inst =
((kMPerBlock / WG::kM) / nloop) * ((kNPerBlock / WG::kN) / nloop);
// How often (in MFMA units) we should insert DS (LDS) operations.
constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
// How often (in MFMA units) we should insert VMEM buffer loads.
@@ -169,7 +174,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
}
// Always mark some VALU work in the loop to reflect auxiliary scalar
// or vector ALU instructions that coexist with MFMA (Blockscale calculation).
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 4, 0); // VALU
});
});
__builtin_amdgcn_sched_barrier(0);
@@ -380,7 +385,6 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
// Prefetch A1
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// initialize C
@@ -407,7 +411,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
while(iCounter > 0)
{
__builtin_amdgcn_sched_barrier(0);
// Prefill A(2i+1)
// Prefill A(2i+1) ds_write
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
@@ -435,10 +439,14 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
});
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// prefetch Q(2i+1)
aq_block_tile_2 = load_tile(aq_copy_dram_window);
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
bq_block_tile_2 = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
// Preload A(2i+1) ds_read
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
@@ -460,6 +468,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
});
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// prefetch Q(2i+1)
aq_block_tile = load_tile(aq_copy_dram_window);
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
bq_block_tile = load_tile(bq_copy_dram_window);
@@ -481,7 +491,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
aq_block_tile_2,
bq_block_tile_2,
a_warp_windows_pong);
// Preload A(2i+2) ds_read
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
@@ -521,7 +531,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
aq_block_tile,
bq_block_tile,
a_warp_windows_ping);
// Preload A ds_read
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;