mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4816 (commit 17ff961)
[CK] Add split-K support for ABQuantGrouped in block_scale_gemm (#4816) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Changes ### Split-K support in `gemm_quant_kernel.hpp` - **`SplitKBatchOffset`**: Added `aq_group_offset` and `aq_k_split_offset` fields (mirroring the existing `bq_*` fields for B) to track each split-K batch's position within the AQ scale tensor. For `ABQuantGrouped`, both offsets are computed from `k_id * KRead` divided by `AQuantGroupSize::kK`. - **`MakeAQBlockWindow`**: Added an `aq_group_offset` parameter (defaulting to 0 for non-split-K paths) so the AQ tensor view's K-group dimension reflects only the remaining K-groups from the split-K offset, consistent with how `MakeBQBlockWindow` handles the BQ tensor. - **`RunGemm`**: Threads the `aq_k_split_offset` through to `MakeAQBlockWindow` when in split-K mode. ### Constraints in `IsSupportedArgument()` Four constraints gate split-K (`k_batch > 1`) for ABQuantGrouped: 1. **Mode check** — split-K is only allowed for `BQuantGrouped` (no preshuffle) or `ABQuantGrouped` (no `APreshuffleQuant`). Any other quant mode with `k_batch > 1` returns `false`. 2. **B quant group alignment** — `KRead` (per-batch K slice) must be divisible by `BQuantGroupSize::kK`. Each batch must operate on complete B quantization groups; a partial group would require splitting a scale value across batches. 3. **A quant group alignment** (new, ABQuantGrouped only) — `KRead` must also be divisible by `AQuantGroupSize::kK` for the same reason applied to the AQ scale tensor. 4. **Minimum 2 K-tile iterations per batch** (new) — The software-pipelined GEMM kernels (CompV3 family) prefetch one tile ahead, so they require `per_batch_num_loop = KRead / KPerBlock >= 2`. When `KRead == KPerBlock` (i.e. each batch is exactly one tile), the prefetch reads into the next batch's memory region and produces incorrect results. Configurations where `K == k_batch * KPerBlock` are therefore rejected. ### Example update (`run_gemm_quant_example.inc`) Updated the comment above the `IsSupportedArgument` call to document that split-K is now supported for both `BQuantGrouped` (no preshuffle) and `ABQuantGrouped` (no `APreshuffleQuant`). ## Unit Tests Two new test files covering decode and prefill tile shapes across a range of `k_batch` values (2–8), data types (FP8, BF8), and quantization group sizes (1×1×128 and 1×128×128 for B): - `test_gemm_quant_abquant_splitk_decode.cpp` — uses the decode tile shape (M=16, N=64, K_tile=256) - `test_gemm_quant_abquant_splitk_prefill.cpp` — uses the prefill tile shape (M=128, N=128, K_tile=128) Each test calls `run_test_with_validation` which runs the kernel and checks correctness against a CPU reference. Configurations excluded from tests are annotated with comments explaining which constraint they violate (typically the `per_batch_num_loop >= 2` requirement). ## Prerequisites This PR depends on #4429, which must be merged before this can be merged.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
6549c320fc
commit
c8a8449eec
@@ -15,7 +15,7 @@ namespace ck_tile {
|
||||
// B is block window on block distributed tensor.
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename BlockPolicy_>
|
||||
struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
struct BlockGemmWeightPreshuffleABQuantARegBRegCReg : public BlockGemmQuantBase
|
||||
{
|
||||
private:
|
||||
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||
@@ -121,6 +121,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
};
|
||||
|
||||
public:
|
||||
using Base = BlockGemmQuantBase;
|
||||
using Traits = GemmTraits_<Problem_, BlockPolicy_>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using BlockPolicy = remove_cvref_t<BlockPolicy_>;
|
||||
@@ -217,22 +218,6 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
auto q_block_tensor = aq_block_tensor;
|
||||
constexpr bool SimpleDequant =
|
||||
Traits::NQPerBlock == 1 &&
|
||||
AccTensor::get_distributed_spans()[I0].impl_.size() == 0; // c_transpose
|
||||
if constexpr(SimpleDequant)
|
||||
{
|
||||
constexpr auto aq_spans = AQBlockTensor::get_distributed_spans();
|
||||
sweep_tile_span(aq_spans[I0], [&](auto im) {
|
||||
sweep_tile_span(aq_spans[I1], [&](auto ik) {
|
||||
q_block_tensor(make_tuple(im, ik)) *=
|
||||
bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik));
|
||||
});
|
||||
});
|
||||
}
|
||||
// hot loop:
|
||||
static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
|
||||
zero_accumulators();
|
||||
static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
|
||||
@@ -265,29 +250,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
}
|
||||
});
|
||||
});
|
||||
static_for_product<number<MIterPerWarp>, number<NIterPerWarp>>{}([&](auto mIter,
|
||||
auto nIter) {
|
||||
if constexpr(SimpleDequant)
|
||||
{
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
constexpr auto block_idx_m = tile_distributed_index<mIter>{};
|
||||
constexpr auto block_idx_kq = tile_distributed_index<kQScale>{};
|
||||
|
||||
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
|
||||
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
|
||||
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
|
||||
c_ref += acc_val * q_block_tensor(make_tuple(block_idx_m, block_idx_kq));
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
|
||||
aq_block_tensor);
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(aq_block_tensor);
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
@@ -305,9 +270,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
return nIter * KPerBlockBQ + kQScale;
|
||||
}
|
||||
}();
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float b_scale_reg_f =
|
||||
aq_picker.template cvt_scale_to_fp32<BQDataType>(scale_reg);
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
float b_scale_reg_f = Base::cvt_scale_to_fp32<BQDataType>(scale_reg);
|
||||
|
||||
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
|
||||
float a_scale_reg_f = aq_picker.template pick<c_row>();
|
||||
@@ -315,7 +279,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
|
||||
c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f;
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -291,66 +291,37 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
"C block tensor data type!");
|
||||
constexpr auto warp_size = get_warp_size();
|
||||
|
||||
// Start from AQ block tensor and then scale it using BQ; this represents
|
||||
// the combined A/B quantization scales for the block.
|
||||
auto q_block_tensor = aq_block_tensor;
|
||||
constexpr bool SimpleDequant =
|
||||
Traits::NQPerBlock == 1 &&
|
||||
CWarpTensor::get_distributed_spans()[I0{}].impl_.size() == 0; // c_transpose
|
||||
if constexpr(SimpleDequant)
|
||||
{
|
||||
constexpr auto aq_spans = AQBlockTensor::get_distributed_spans();
|
||||
sweep_tile_span(aq_spans[I0{}], [&](auto im) {
|
||||
sweep_tile_span(aq_spans[I1{}], [&](auto ik) {
|
||||
q_block_tensor(make_tuple(im, ik)) *=
|
||||
bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// hot loop:
|
||||
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
|
||||
static_for_product<number<MIterPerWarp>, number<NIterPerWarp>>{}([&](auto mIter,
|
||||
auto nIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
|
||||
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
|
||||
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
|
||||
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
|
||||
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
|
||||
|
||||
if constexpr(kIterInQScale == 0)
|
||||
{
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
});
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() =
|
||||
a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
if constexpr(SimpleDequant)
|
||||
{
|
||||
constexpr auto cw_spans = CWarpTensor::get_distributed_spans();
|
||||
sweep_tile_span(cw_spans[I1{}], [&](auto in) {
|
||||
constexpr auto block_idx_m = tile_distributed_index<mIter>{};
|
||||
constexpr auto block_idx_n = detail::make_tile_distributed_index(
|
||||
merge_sequences(sequence<nIter>{}, in.impl_));
|
||||
constexpr auto block_idx_kq = tile_distributed_index<kQScale>{};
|
||||
constexpr auto empty_idx = tile_distributed_index<>{};
|
||||
c_block_tensor(make_tuple(block_idx_m, block_idx_n)) +=
|
||||
c_warp_tensor(make_tuple(empty_idx, in)) *
|
||||
q_block_tensor(make_tuple(block_idx_m, block_idx_kq));
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() =
|
||||
b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
if constexpr(kIterInQScale == 0)
|
||||
{
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
@@ -435,7 +406,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
b_scale_reg_f);
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user