[rocm-libraries] ROCm/rocm-libraries#4816 (commit 17ff961)

[CK] Add split-K support for ABQuantGrouped in
 block_scale_gemm (#4816)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Changes

### Split-K support in `gemm_quant_kernel.hpp`

- **`SplitKBatchOffset`**: Added `aq_group_offset` and
`aq_k_split_offset` fields (mirroring the existing `bq_*` fields for B)
to track each split-K batch's position within the AQ scale tensor. For
`ABQuantGrouped`, both offsets are computed from `k_id * KRead` divided
by `AQuantGroupSize::kK`.

- **`MakeAQBlockWindow`**: Added an `aq_group_offset` parameter
(defaulting to 0 for non-split-K paths) so the AQ tensor view's K-group
dimension reflects only the remaining K-groups from the split-K offset,
consistent with how `MakeBQBlockWindow` handles the BQ tensor.

- **`RunGemm`**: Threads the `aq_k_split_offset` through to
`MakeAQBlockWindow` when in split-K mode.

### Constraints in `IsSupportedArgument()`

Four constraints gate split-K (`k_batch > 1`) for ABQuantGrouped:

1. **Mode check** — split-K is only allowed for `BQuantGrouped` (no
preshuffle) or `ABQuantGrouped` (no `APreshuffleQuant`). Any other quant
mode with `k_batch > 1` returns `false`.

2. **B quant group alignment** — `KRead` (per-batch K slice) must be
divisible by `BQuantGroupSize::kK`. Each batch must operate on complete
B quantization groups; a partial group would require splitting a scale
value across batches.

3. **A quant group alignment** (new, ABQuantGrouped only) — `KRead` must
also be divisible by `AQuantGroupSize::kK` for the same reason applied
to the AQ scale tensor.

4. **Minimum 2 K-tile iterations per batch** (new) — The
software-pipelined GEMM kernels (CompV3 family) prefetch one tile ahead,
so they require `per_batch_num_loop = KRead / KPerBlock >= 2`. When
`KRead == KPerBlock` (i.e. each batch is exactly one tile), the prefetch
reads into the next batch's memory region and produces incorrect
results. Configurations where `K == k_batch * KPerBlock` are therefore
rejected.

### Example update (`run_gemm_quant_example.inc`)

Updated the comment above the `IsSupportedArgument` call to document
that split-K is now supported for both `BQuantGrouped` (no preshuffle)
and `ABQuantGrouped` (no `APreshuffleQuant`).

## Unit Tests

Two new test files covering decode and prefill tile shapes across a
range of `k_batch` values (2–8), data types (FP8, BF8), and quantization
group sizes (1×1×128 and 1×128×128 for B):

- `test_gemm_quant_abquant_splitk_decode.cpp` — uses the decode tile
shape (M=16, N=64, K_tile=256)
- `test_gemm_quant_abquant_splitk_prefill.cpp` — uses the prefill tile
shape (M=128, N=128, K_tile=128)

Each test calls `run_test_with_validation` which runs the kernel and
checks correctness against a CPU reference. Configurations excluded from
tests are annotated with comments explaining which constraint they
violate (typically the `per_batch_num_loop >= 2` requirement).

## Prerequisites

This PR depends on #4429, which must be merged before this can be
merged.
This commit is contained in:
Aviral Goel
2026-02-26 23:57:17 +00:00
committed by assistant-librarian[bot]
parent 6549c320fc
commit c8a8449eec
23 changed files with 796 additions and 418 deletions

View File

@@ -108,14 +108,10 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName());
// clang-format on
}
/**
* @tparam nloop The number of iterations in the hot loop,
* used to normalize scheduling costs.
*/
template <index_t nloop>
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
{
static_assert(nloop > 0, "nloop must be greater than 0");
// Estimated number of VMEM vector loads for A per block:
// total A bytes / (threads per block * vector width)
constexpr index_t Aload_inst =
@@ -138,13 +134,12 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
// Total VMEM load instructions (A + B + quant data)
constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;
// Approximate number of LDS reads per block
constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle / nloop;
constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle;
// Approximate number of LDS writes per block
// (e.g., writing A from VMEM into LDS once per A load)
constexpr index_t ds_write_inst = Aload_inst;
// Number of MFMA instructions per wave for one block tile:
constexpr index_t mfma_inst =
((kMPerBlock / WG::kM) / nloop) * ((kNPerBlock / WG::kN) / nloop);
constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
// How often (in MFMA units) we should insert DS (LDS) operations.
constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
// How often (in MFMA units) we should insert VMEM buffer loads.
@@ -181,7 +176,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
}
// Always mark some VALU work in the loop to reflect auxiliary scalar
// or vector ALU instructions that coexist with MFMA (Blockscale calculation).
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 4, 0); // VALU
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU
});
});
__builtin_amdgcn_sched_barrier(0);
@@ -409,6 +404,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
// Prefetch A1
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// initialize C
@@ -437,7 +433,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
while(iCounter > 0)
{
__builtin_amdgcn_sched_barrier(0);
// Prefill A(2i+1) ds_write
// Prefill A(2i+1)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
@@ -465,14 +461,10 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
});
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// prefetch Q(2i+1)
aq_block_tile_2 = load_tile(aq_copy_dram_window);
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
bq_block_tile_2 = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, bq_dram_tile_window_step);
// Preload A(2i+1) ds_read
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
@@ -494,8 +486,6 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
});
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// prefetch Q(2i+1)
aq_block_tile = load_tile(aq_copy_dram_window);
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
bq_block_tile = load_tile(bq_copy_dram_window);
@@ -517,7 +507,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
aq_block_tile_2,
bq_block_tile_2,
a_warp_windows_pong);
// Preload A(2i+2) ds_read
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
@@ -557,7 +547,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
aq_block_tile,
bq_block_tile,
a_warp_windows_ping);
// Preload A ds_read
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;