mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4816 (commit 17ff961)
[CK] Add split-K support for ABQuantGrouped in block_scale_gemm (#4816) ## Changes ### Split-K support in `gemm_quant_kernel.hpp` - **`SplitKBatchOffset`**: Added `aq_group_offset` and `aq_k_split_offset` fields (mirroring the existing `bq_*` fields for B) to track each split-K batch's position within the AQ scale tensor. For `ABQuantGrouped`, both offsets are computed from `k_id * KRead` divided by `AQuantGroupSize::kK`. - **`MakeAQBlockWindow`**: Added an `aq_group_offset` parameter (defaulting to 0 for non-split-K paths) so the AQ tensor view's K-group dimension reflects only the remaining K-groups from the split-K offset, consistent with how `MakeBQBlockWindow` handles the BQ tensor. - **`RunGemm`**: Threads the `aq_k_split_offset` through to `MakeAQBlockWindow` when in split-K mode. ### Constraints in `IsSupportedArgument()` Four constraints gate split-K (`k_batch > 1`) for ABQuantGrouped: 1. **Mode check** — split-K is only allowed for `BQuantGrouped` (no preshuffle) or `ABQuantGrouped` (no `APreshuffleQuant`). Any other quant mode with `k_batch > 1` returns `false`. 2. **B quant group alignment** — `KRead` (per-batch K slice) must be divisible by `BQuantGroupSize::kK`. Each batch must operate on complete B quantization groups; a partial group would require splitting a scale value across batches. 3. **A quant group alignment** (new, ABQuantGrouped only) — `KRead` must also be divisible by `AQuantGroupSize::kK` for the same reason applied to the AQ scale tensor. 4. **Minimum 2 K-tile iterations per batch** (new) — The software-pipelined GEMM kernels (CompV3 family) prefetch one tile ahead, so they require `per_batch_num_loop = KRead / KPerBlock >= 2`. When `KRead == KPerBlock` (i.e. each batch is exactly one tile), the prefetch reads into the next batch's memory region and produces incorrect results. Configurations where `K == k_batch * KPerBlock` are therefore rejected. 
### Example update (`run_gemm_quant_example.inc`) Updated the comment above the `IsSupportedArgument` call to document that split-K is now supported for both `BQuantGrouped` (no preshuffle) and `ABQuantGrouped` (no `APreshuffleQuant`). ## Unit Tests Two new test files covering decode and prefill tile shapes across a range of `k_batch` values (2–8), data types (FP8, BF8), and quantization group sizes (1×1×128 and 1×128×128 for B): - `test_gemm_quant_abquant_splitk_decode.cpp` — uses the decode tile shape (M=16, N=64, K_tile=256) - `test_gemm_quant_abquant_splitk_prefill.cpp` — uses the prefill tile shape (M=128, N=128, K_tile=128) Each test calls `run_test_with_validation` which runs the kernel and checks correctness against a CPU reference. Configurations excluded from tests are annotated with comments explaining which constraint they violate (typically the `per_batch_num_loop >= 2` requirement). ## Prerequisites This PR depends on #4429, which must be merged before this can be merged.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
6549c320fc
commit
c8a8449eec
@@ -448,18 +448,46 @@ struct QuantGemmKernel
|
||||
// offset = bq_group_offset
|
||||
bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset);
|
||||
}
|
||||
|
||||
aq_group_offset = 0;
|
||||
aq_k_split_offset = 0;
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant)
|
||||
{
|
||||
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
|
||||
// Compute AQ K-group offset for this split-K batch.
|
||||
// AQ tensor layout is RowMajor [M, QK_A] with stride [stride_AQ, 1].
|
||||
// Advancing to column aq_group_offset means a pointer offset of aq_group_offset
|
||||
// elements (column stride = 1).
|
||||
const index_t k_offset_aq = amd_wave_read_first_lane(k_id * KRead);
|
||||
aq_group_offset = amd_wave_read_first_lane(k_offset_aq / AQuantGroupSize::kK);
|
||||
aq_k_split_offset = amd_wave_read_first_lane(aq_group_offset);
|
||||
|
||||
// Compute BQ K-group offset for this split-K batch.
|
||||
// BQ tensor layout is ColumnMajor [N/kN, K/kK] with stride [K/kK, 1] for
|
||||
// ABQuantGrouped. Advancing to column bq_group_offset means a pointer offset of
|
||||
// bq_group_offset elements (column stride = 1).
|
||||
const index_t k_offset_bq = amd_wave_read_first_lane(k_id * KRead);
|
||||
bq_group_offset = amd_wave_read_first_lane(k_offset_bq / BQuantGroupSize::kK);
|
||||
bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset);
|
||||
}
|
||||
else
|
||||
{
|
||||
bq_group_offset = 0;
|
||||
bq_k_split_offset = 0;
|
||||
aq_group_offset = 0;
|
||||
aq_k_split_offset = 0;
|
||||
}
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t bq_group_offset; // Logical offset in K-groups (K/kK dimension)
|
||||
index_t bq_k_split_offset; // Memory pointer offset (accounting for layout/stride)
|
||||
index_t aq_group_offset; // Logical offset in K-groups for AQ (K/kK dimension)
|
||||
index_t aq_k_split_offset; // Memory pointer offset for AQ
|
||||
index_t bq_group_offset; // Logical offset in K-groups for BQ (K/kK dimension)
|
||||
index_t bq_k_split_offset; // Memory pointer offset for BQ (accounting for layout/stride)
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
@@ -532,7 +560,8 @@ struct QuantGemmKernel
|
||||
CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
const index_t i_n,
|
||||
const index_t aq_group_offset = 0)
|
||||
{
|
||||
// Step 1: Create tensor view for AQ
|
||||
const auto& aq_tensor_view = [&]() {
|
||||
@@ -615,11 +644,14 @@ struct QuantGemmKernel
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant)
|
||||
{
|
||||
// For split-K, aq_ptr is already offset by aq_k_split_offset elements.
|
||||
// The remaining K-groups from this offset position = QK_A - aq_group_offset.
|
||||
const index_t remaining_qk_a = kargs.QK_A - aq_group_offset;
|
||||
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.M, kargs.QK_A),
|
||||
make_tuple(kargs.M, remaining_qk_a),
|
||||
make_tuple(kargs.stride_AQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
@@ -628,9 +660,8 @@ struct QuantGemmKernel
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.M, kargs.QK_A),
|
||||
make_tuple(kargs.M, remaining_qk_a),
|
||||
make_tuple(1, kargs.stride_AQ),
|
||||
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
@@ -1100,26 +1131,32 @@ struct QuantGemmKernel
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
|
||||
{
|
||||
// Split-K is supported for BQuantGrouped mode without preshuffle
|
||||
// Split-K is supported for BQuantGrouped (without preshuffle) and
|
||||
// ABQuantGrouped (without APreshuffleQuant) modes.
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
constexpr bool is_bquant_non_preshuffle =
|
||||
(kQuantType == QuantType::BQuantGrouped) && !BPreshuffleQuant;
|
||||
if constexpr(!is_bquant_non_preshuffle)
|
||||
constexpr bool is_abquant_non_preshuffle =
|
||||
(kQuantType == QuantType::ABQuantGrouped) && !APreshuffleQuant;
|
||||
constexpr bool is_splitk_supported =
|
||||
is_bquant_non_preshuffle || is_abquant_non_preshuffle;
|
||||
|
||||
if constexpr(!is_splitk_supported)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for Kbatch >1 ! "
|
||||
"Split-K only supported for BQuantGrouped without preshuffle.");
|
||||
"Split-K is supported for BQuantGrouped without preshuffle "
|
||||
"and ABQuantGrouped without APreshuffleQuant.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
|
||||
const index_t K_t = kargs.k_batch * K1;
|
||||
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
|
||||
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
|
||||
const index_t K_t = kargs.k_batch * K1;
|
||||
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; // per-batch K read size
|
||||
constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
@@ -1137,22 +1174,67 @@ struct QuantGemmKernel
|
||||
return false;
|
||||
}
|
||||
|
||||
// Constraint 2: KRead must align with quantization group boundaries.
|
||||
// Each split-K batch reads KRead consecutive K elements. If KRead is not
|
||||
// a multiple of BQuantGroupSize::kK, the batch will span partial quantization
|
||||
// groups, requiring split access to a quantization scale. This violates the
|
||||
// atomic processing requirement where each batch must work with complete groups.
|
||||
if(KRead % BQuantGroupSize::kK != 0)
|
||||
// Constraint 2: KRead must align with B quantization group boundaries.
|
||||
if constexpr(is_bquant_non_preshuffle || is_abquant_non_preshuffle)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
if(KRead % BQuantGroupSize::kK != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Split-K batch size must be aligned with quantization group "
|
||||
"size! KRead=" +
|
||||
std::to_string(KRead) +
|
||||
" is not divisible by BQuantGroupSize::kK=" +
|
||||
std::to_string(BQuantGroupSize::kK));
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Split-K batch size must be aligned with B quantization group "
|
||||
"size! KRead=" +
|
||||
std::to_string(KRead) +
|
||||
" is not divisible by BQuantGroupSize::kK=" +
|
||||
std::to_string(BQuantGroupSize::kK));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Constraint 3: KRead must align with A quantization group boundaries
|
||||
// (only needed for ABQuantGrouped since AQ also indexes into K).
|
||||
if constexpr(is_abquant_non_preshuffle)
|
||||
{
|
||||
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
|
||||
if(KRead % AQuantGroupSize::kK != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Split-K batch size must be aligned with A quantization group "
|
||||
"size! KRead=" +
|
||||
std::to_string(KRead) +
|
||||
" is not divisible by AQuantGroupSize::kK=" +
|
||||
std::to_string(AQuantGroupSize::kK));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Constraint 4: per-batch K must span at least 2 K_Tile iterations.
|
||||
// The software-pipelined GEMM kernels (CompV3 family) prefetch one tile
|
||||
// ahead and require num_loop >= 2 per batch. When KRead == KPerBlock
|
||||
// (i.e. per_batch_num_loop == 1) the prefetch would read the tile
|
||||
// belonging to the next split-K batch, producing incorrect results.
|
||||
{
|
||||
const index_t per_batch_num_loop =
|
||||
KRead / static_cast<index_t>(TilePartitioner::KPerBlock);
|
||||
if(per_batch_num_loop < 2)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Split-K requires at least 2 K-tile iterations per batch. "
|
||||
"KRead=" +
|
||||
std::to_string(KRead) + " < 2 * KPerBlock=" +
|
||||
std::to_string(2 *
|
||||
static_cast<index_t>(TilePartitioner::KPerBlock)) +
|
||||
". Increase K or decrease k_batch.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1243,6 +1325,18 @@ struct QuantGemmKernel
|
||||
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
// For RowMajor C, M is the row dimension — check M alignment here because
|
||||
// ALayout=RowMajor does not check M (it only checks K), leaving a gap for
|
||||
// the RowMajorA + RowMajorC combination.
|
||||
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support M that is not a multiple of MPerBlock without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
@@ -1315,7 +1409,10 @@ struct QuantGemmKernel
|
||||
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_block_window =
|
||||
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
// Note: Pass aq_group_offset so the tensor view dimension reflects
|
||||
// the remaining K-groups from the split-K offset position.
|
||||
const auto& aq_block_window = MakeAQBlockWindow(
|
||||
aq_ptr, kargs, block_idx_m, block_idx_n, splitk_batch_offset.aq_group_offset);
|
||||
// Note: Pass bq_group_offset so the tensor view dimension reflects
|
||||
// the remaining K-groups from the split-K offset position.
|
||||
const auto& bq_block_window = MakeBQBlockWindow(
|
||||
@@ -1445,7 +1542,10 @@ struct QuantGemmKernel
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
|
||||
// For ABQuantGrouped split-K, aq_ptr is offset by aq_k_split_offset elements to point
|
||||
// to the start of this batch's AQ K-groups (aq_group_offset columns in RowMajor AQ).
|
||||
const AQDataType* aq_ptr =
|
||||
static_cast<const AQDataType*>(kargs.aq_ptr) + splitk_batch_offset.aq_k_split_offset;
|
||||
const BQDataType* bq_ptr =
|
||||
static_cast<const BQDataType*>(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
Reference in New Issue
Block a user