[rocm-libraries] ROCm/rocm-libraries#4816 (commit 17ff961)

[CK] Add split-K support for ABQuantGrouped in
 block_scale_gemm (#4816)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Changes

### Split-K support in `gemm_quant_kernel.hpp`

- **`SplitKBatchOffset`**: Added `aq_group_offset` and
`aq_k_split_offset` fields (mirroring the existing `bq_*` fields for B)
to track each split-K batch's position within the AQ scale tensor. For
`ABQuantGrouped`, both offsets are computed from `k_id * KRead` divided
by `AQuantGroupSize::kK`.

- **`MakeAQBlockWindow`**: Added an `aq_group_offset` parameter
(defaulting to 0 for non-split-K paths) so the AQ tensor view's K-group
dimension reflects only the remaining K-groups from the split-K offset,
consistent with how `MakeBQBlockWindow` handles the BQ tensor.

- **`RunGemm`**: Threads the `aq_k_split_offset` through to
`MakeAQBlockWindow` when in split-K mode.

### Constraints in `IsSupportedArgument()`

Four constraints gate split-K (`k_batch > 1`) for ABQuantGrouped:

1. **Mode check** — split-K is only allowed for `BQuantGrouped` (no
preshuffle) or `ABQuantGrouped` (no `APreshuffleQuant`). Any other quant
mode with `k_batch > 1` returns `false`.

2. **B quant group alignment** — `KRead` (per-batch K slice) must be
divisible by `BQuantGroupSize::kK`. Each batch must operate on complete
B quantization groups; a partial group would require splitting a scale
value across batches.

3. **A quant group alignment** (new, ABQuantGrouped only) — `KRead` must
also be divisible by `AQuantGroupSize::kK` for the same reason applied
to the AQ scale tensor.

4. **Minimum 2 K-tile iterations per batch** (new) — The
software-pipelined GEMM kernels (CompV3 family) prefetch one tile ahead,
so they require `per_batch_num_loop = KRead / KPerBlock >= 2`. When
`KRead == KPerBlock` (i.e. each batch is exactly one tile), the prefetch
reads into the next batch's memory region and produces incorrect
results. Configurations where `K == k_batch * KPerBlock` are therefore
rejected.

### Example update (`run_gemm_quant_example.inc`)

Updated the comment above the `IsSupportedArgument` call to document
that split-K is now supported for both `BQuantGrouped` (no preshuffle)
and `ABQuantGrouped` (no `APreshuffleQuant`).

## Unit Tests

Two new test files covering decode and prefill tile shapes across a
range of `k_batch` values (2–8), data types (FP8, BF8), and quantization
group sizes (1×1×128 and 1×128×128 for B):

- `test_gemm_quant_abquant_splitk_decode.cpp` — uses the decode tile
shape (M=16, N=64, K_tile=256)
- `test_gemm_quant_abquant_splitk_prefill.cpp` — uses the prefill tile
shape (M=128, N=128, K_tile=128)

Each test calls `run_test_with_validation` which runs the kernel and
checks correctness against a CPU reference. Configurations excluded from
tests are annotated with comments explaining which constraint they
violate (typically the `per_batch_num_loop >= 2` requirement).

## Prerequisites

This PR depends on #4429, which must be merged before this can be
merged.
This commit is contained in:
Aviral Goel
2026-02-26 23:57:17 +00:00
committed by assistant-librarian[bot]
parent 6549c320fc
commit c8a8449eec
23 changed files with 796 additions and 418 deletions

View File

@@ -19,13 +19,13 @@ template <typename TileDistributedSpan_, // tile_distributed_span<...>
>
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
{
using DstrSpanImpl = typename remove_cvref_t<TileDistributedSpan_>::Impl;
using DstrSpan = remove_cvref_t<TileDistributedSpan_>;
if constexpr(DstrSpanImpl::size() == 0) // handle the 0-dim span case
f(detail::make_tile_distributed_index(sequence<>{}));
else
static_ford<DstrSpanImpl>{}(
[&](auto dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)); });
static_ford<typename DstrSpan::Impl>{}([&](auto dstr_idx_impl) {
constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl);
f(dstr_idx);
});
}
// unpacked span, this version support span with unpack(multi-arg) functor

View File

@@ -15,7 +15,7 @@ namespace ck_tile {
// B is block window on block distributed tensor.
// C is block distributed tensor
template <typename Problem_, typename BlockPolicy_>
struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
struct BlockGemmWeightPreshuffleABQuantARegBRegCReg : public BlockGemmQuantBase
{
private:
template <typename PipelineProblem_, typename GemmPolicy_>
@@ -121,6 +121,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
};
public:
using Base = BlockGemmQuantBase;
using Traits = GemmTraits_<Problem_, BlockPolicy_>;
using Problem = remove_cvref_t<Problem_>;
using BlockPolicy = remove_cvref_t<BlockPolicy_>;
@@ -217,22 +218,6 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
});
});
};
auto q_block_tensor = aq_block_tensor;
constexpr bool SimpleDequant =
Traits::NQPerBlock == 1 &&
AccTensor::get_distributed_spans()[I0].impl_.size() == 0; // c_transpose
if constexpr(SimpleDequant)
{
constexpr auto aq_spans = AQBlockTensor::get_distributed_spans();
sweep_tile_span(aq_spans[I0], [&](auto im) {
sweep_tile_span(aq_spans[I1], [&](auto ik) {
q_block_tensor(make_tuple(im, ik)) *=
bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik));
});
});
}
// hot loop:
static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
zero_accumulators();
static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
@@ -265,29 +250,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
}
});
});
static_for_product<number<MIterPerWarp>, number<NIterPerWarp>>{}([&](auto mIter,
auto nIter) {
if constexpr(SimpleDequant)
{
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
constexpr auto block_idx_m = tile_distributed_index<mIter>{};
constexpr auto block_idx_kq = tile_distributed_index<kQScale>{};
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref += acc_val * q_block_tensor(make_tuple(block_idx_m, block_idx_kq));
});
}
else
{
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
aq_block_tensor);
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(aq_block_tensor);
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
@@ -305,9 +270,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
return nIter * KPerBlockBQ + kQScale;
}
}();
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_reg_f =
aq_picker.template cvt_scale_to_fp32<BQDataType>(scale_reg);
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_reg_f = Base::cvt_scale_to_fp32<BQDataType>(scale_reg);
static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
float a_scale_reg_f = aq_picker.template pick<c_row>();
@@ -315,7 +279,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f;
});
}
});
});
});
}

View File

@@ -291,66 +291,37 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
"C block tensor data type!");
constexpr auto warp_size = get_warp_size();
// Start from AQ block tensor and then scale it using BQ; this represents
// the combined A/B quantization scales for the block.
auto q_block_tensor = aq_block_tensor;
constexpr bool SimpleDequant =
Traits::NQPerBlock == 1 &&
CWarpTensor::get_distributed_spans()[I0{}].impl_.size() == 0; // c_transpose
if constexpr(SimpleDequant)
{
constexpr auto aq_spans = AQBlockTensor::get_distributed_spans();
sweep_tile_span(aq_spans[I0{}], [&](auto im) {
sweep_tile_span(aq_spans[I1{}], [&](auto ik) {
q_block_tensor(make_tuple(im, ik)) *=
bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik));
});
});
}
// hot loop:
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
static_for_product<number<MIterPerWarp>, number<NIterPerWarp>>{}([&](auto mIter,
auto nIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
CWarpTensor c_warp_tensor;
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
if constexpr(kIterInQScale == 0)
{
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
});
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() =
a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
if constexpr(SimpleDequant)
{
constexpr auto cw_spans = CWarpTensor::get_distributed_spans();
sweep_tile_span(cw_spans[I1{}], [&](auto in) {
constexpr auto block_idx_m = tile_distributed_index<mIter>{};
constexpr auto block_idx_n = detail::make_tile_distributed_index(
merge_sequences(sequence<nIter>{}, in.impl_));
constexpr auto block_idx_kq = tile_distributed_index<kQScale>{};
constexpr auto empty_idx = tile_distributed_index<>{};
c_block_tensor(make_tuple(block_idx_m, block_idx_n)) +=
c_warp_tensor(make_tuple(empty_idx, in)) *
q_block_tensor(make_tuple(block_idx_m, block_idx_kq));
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
if constexpr(kIterInQScale == 0)
{
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
});
}
else
{
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
@@ -435,7 +406,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
b_scale_reg_f);
});
}
}
});
});
});
}

View File

@@ -448,18 +448,46 @@ struct QuantGemmKernel
// offset = bq_group_offset
bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset);
}
aq_group_offset = 0;
aq_k_split_offset = 0;
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant)
{
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
// Compute AQ K-group offset for this split-K batch.
// AQ tensor layout is RowMajor [M, QK_A] with stride [stride_AQ, 1].
// Advancing to column aq_group_offset means a pointer offset of aq_group_offset
// elements (column stride = 1).
const index_t k_offset_aq = amd_wave_read_first_lane(k_id * KRead);
aq_group_offset = amd_wave_read_first_lane(k_offset_aq / AQuantGroupSize::kK);
aq_k_split_offset = amd_wave_read_first_lane(aq_group_offset);
// Compute BQ K-group offset for this split-K batch.
// BQ tensor layout is ColumnMajor [N/kN, K/kK] with stride [K/kK, 1] for
// ABQuantGrouped. Advancing to column bq_group_offset means a pointer offset of
// bq_group_offset elements (column stride = 1).
const index_t k_offset_bq = amd_wave_read_first_lane(k_id * KRead);
bq_group_offset = amd_wave_read_first_lane(k_offset_bq / BQuantGroupSize::kK);
bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset);
}
else
{
bq_group_offset = 0;
bq_k_split_offset = 0;
aq_group_offset = 0;
aq_k_split_offset = 0;
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t bq_group_offset; // Logical offset in K-groups (K/kK dimension)
index_t bq_k_split_offset; // Memory pointer offset (accounting for layout/stride)
index_t aq_group_offset; // Logical offset in K-groups for AQ (K/kK dimension)
index_t aq_k_split_offset; // Memory pointer offset for AQ
index_t bq_group_offset; // Logical offset in K-groups for BQ (K/kK dimension)
index_t bq_k_split_offset; // Memory pointer offset for BQ (accounting for layout/stride)
index_t splitted_k;
};
@@ -532,7 +560,8 @@ struct QuantGemmKernel
CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr,
const QuantGemmKernelArgs& kargs,
const index_t i_m,
const index_t i_n)
const index_t i_n,
const index_t aq_group_offset = 0)
{
// Step 1: Create tensor view for AQ
const auto& aq_tensor_view = [&]() {
@@ -615,11 +644,14 @@ struct QuantGemmKernel
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant)
{
// For split-K, aq_ptr is already offset by aq_k_split_offset elements.
// The remaining K-groups from this offset position = QK_A - aq_group_offset.
const index_t remaining_qk_a = kargs.QK_A - aq_group_offset;
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK_A),
make_tuple(kargs.M, remaining_qk_a),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
@@ -628,9 +660,8 @@ struct QuantGemmKernel
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK_A),
make_tuple(kargs.M, remaining_qk_a),
make_tuple(1, kargs.stride_AQ),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
@@ -1100,26 +1131,32 @@ struct QuantGemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
{
// Split-K is supported for BQuantGrouped mode without preshuffle
// Split-K is supported for BQuantGrouped (without preshuffle) and
// ABQuantGrouped (without APreshuffleQuant) modes.
if(kargs.k_batch != 1)
{
constexpr bool is_bquant_non_preshuffle =
(kQuantType == QuantType::BQuantGrouped) && !BPreshuffleQuant;
if constexpr(!is_bquant_non_preshuffle)
constexpr bool is_abquant_non_preshuffle =
(kQuantType == QuantType::ABQuantGrouped) && !APreshuffleQuant;
constexpr bool is_splitk_supported =
is_bquant_non_preshuffle || is_abquant_non_preshuffle;
if constexpr(!is_splitk_supported)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Conditions not met for Kbatch >1 ! "
"Split-K only supported for BQuantGrouped without preshuffle.");
"Split-K is supported for BQuantGrouped without preshuffle "
"and ABQuantGrouped without APreshuffleQuant.");
}
return false;
}
else
{
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
const index_t K_t = kargs.k_batch * K1;
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
const index_t K_t = kargs.k_batch * K1;
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; // per-batch K read size
constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
@@ -1137,22 +1174,67 @@ struct QuantGemmKernel
return false;
}
// Constraint 2: KRead must align with quantization group boundaries.
// Each split-K batch reads KRead consecutive K elements. If KRead is not
// a multiple of BQuantGroupSize::kK, the batch will span partial quantization
// groups, requiring split access to a quantization scale. This violates the
// atomic processing requirement where each batch must work with complete groups.
if(KRead % BQuantGroupSize::kK != 0)
// Constraint 2: KRead must align with B quantization group boundaries.
if constexpr(is_bquant_non_preshuffle || is_abquant_non_preshuffle)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
if(KRead % BQuantGroupSize::kK != 0)
{
CK_TILE_ERROR("Split-K batch size must be aligned with quantization group "
"size! KRead=" +
std::to_string(KRead) +
" is not divisible by BQuantGroupSize::kK=" +
std::to_string(BQuantGroupSize::kK));
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Split-K batch size must be aligned with B quantization group "
"size! KRead=" +
std::to_string(KRead) +
" is not divisible by BQuantGroupSize::kK=" +
std::to_string(BQuantGroupSize::kK));
}
return false;
}
}
// Constraint 3: KRead must align with A quantization group boundaries
// (only needed for ABQuantGrouped since AQ also indexes into K).
if constexpr(is_abquant_non_preshuffle)
{
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
if(KRead % AQuantGroupSize::kK != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Split-K batch size must be aligned with A quantization group "
"size! KRead=" +
std::to_string(KRead) +
" is not divisible by AQuantGroupSize::kK=" +
std::to_string(AQuantGroupSize::kK));
}
return false;
}
}
// Constraint 4: per-batch K must span at least 2 K_Tile iterations.
// The software-pipelined GEMM kernels (CompV3 family) prefetch one tile
// ahead and require num_loop >= 2 per batch. When KRead == KPerBlock
// (i.e. per_batch_num_loop == 1) the prefetch would read the tile
// belonging to the next split-K batch, producing incorrect results.
{
const index_t per_batch_num_loop =
KRead / static_cast<index_t>(TilePartitioner::KPerBlock);
if(per_batch_num_loop < 2)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Split-K requires at least 2 K-tile iterations per batch. "
"KRead=" +
std::to_string(KRead) + " < 2 * KPerBlock=" +
std::to_string(2 *
static_cast<index_t>(TilePartitioner::KPerBlock)) +
". Increase K or decrease k_batch.");
}
return false;
}
return false;
}
}
}
@@ -1243,6 +1325,18 @@ struct QuantGemmKernel
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
// For RowMajor C, M is the row dimension — check M alignment here because
// ALayout=RowMajor does not check M (it only checks K), leaving a gap for
// the RowMajorA + RowMajorC combination.
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Can't support M that is not a multiple of MPerBlock without padding!");
}
return false;
}
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
@@ -1315,7 +1409,10 @@ struct QuantGemmKernel
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_block_window =
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
// Note: Pass aq_group_offset so the tensor view dimension reflects
// the remaining K-groups from the split-K offset position.
const auto& aq_block_window = MakeAQBlockWindow(
aq_ptr, kargs, block_idx_m, block_idx_n, splitk_batch_offset.aq_group_offset);
// Note: Pass bq_group_offset so the tensor view dimension reflects
// the remaining K-groups from the split-K offset position.
const auto& bq_block_window = MakeBQBlockWindow(
@@ -1445,7 +1542,10 @@ struct QuantGemmKernel
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
// For ABQuantGrouped split-K, aq_ptr is offset by aq_k_split_offset elements to point
// to the start of this batch's AQ K-groups (aq_group_offset columns in RowMajor AQ).
const AQDataType* aq_ptr =
static_cast<const AQDataType*>(kargs.aq_ptr) + splitk_batch_offset.aq_k_split_offset;
const BQDataType* bq_ptr =
static_cast<const BQDataType*>(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset;
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);

View File

@@ -108,14 +108,10 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName());
// clang-format on
}
/**
* @tparam nloop The number of iterations in the hot loop,
* used to normalize scheduling costs.
*/
template <index_t nloop>
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
{
static_assert(nloop > 0, "nloop must be greater than 0");
// Estimated number of VMEM vector loads for A per block:
// total A bytes / (threads per block * vector width)
constexpr index_t Aload_inst =
@@ -138,13 +134,12 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
// Total VMEM load instructions (A + B + quant data)
constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;
// Approximate number of LDS reads per block
constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle / nloop;
constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle;
// Approximate number of LDS writes per block
// (e.g., writing A from VMEM into LDS once per A load)
constexpr index_t ds_write_inst = Aload_inst;
// Number of MFMA instructions per wave for one block tile:
constexpr index_t mfma_inst =
((kMPerBlock / WG::kM) / nloop) * ((kNPerBlock / WG::kN) / nloop);
constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
// How often (in MFMA units) we should insert DS (LDS) operations.
constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
// How often (in MFMA units) we should insert VMEM buffer loads.
@@ -181,7 +176,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
}
// Always mark some VALU work in the loop to reflect auxiliary scalar
// or vector ALU instructions that coexist with MFMA (Blockscale calculation).
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 4, 0); // VALU
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU
});
});
__builtin_amdgcn_sched_barrier(0);
@@ -409,6 +404,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
// Prefetch A1
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// initialize C
@@ -437,7 +433,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
while(iCounter > 0)
{
__builtin_amdgcn_sched_barrier(0);
// Prefill A(2i+1) ds_write
// Prefill A(2i+1)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
@@ -465,14 +461,10 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
});
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// prefetch Q(2i+1)
aq_block_tile_2 = load_tile(aq_copy_dram_window);
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
bq_block_tile_2 = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, bq_dram_tile_window_step);
// Preload A(2i+1) ds_read
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
@@ -494,8 +486,6 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
});
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// prefetch Q(2i+1)
aq_block_tile = load_tile(aq_copy_dram_window);
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
bq_block_tile = load_tile(bq_copy_dram_window);
@@ -517,7 +507,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
aq_block_tile_2,
bq_block_tile_2,
a_warp_windows_pong);
// Preload A(2i+2) ds_read
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
@@ -557,7 +547,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
aq_block_tile,
bq_block_tile,
a_warp_windows_ping);
// Preload A ds_read
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;