[CK Tile] Grouped GEMM aquant mode and non-persistent kernel (#3337)

* wip: add aquant to grouped gemm quant example

* fix: properly handle hot loop count in aquant pipeline

* fix: add separate GemmConfig structs for AQuant, automatically select the correct one

* feat: finish support for a non-persistent kernel invocation for grouped gemm quant, and add support code to example

* refactor: cleaned up grouped gemm quant example a bit by reusing pipeline selection logic

* chore: add warp gemm dispatchers for a couple of TransposeC K=32 variants

* feat: add quant grouped gemm test cases for aquant (regular and transpose C) and non-persistent kernel

* fix: update base pipeline classes according to changes in develop branch

* Revert "chore: add warp gemm dispatchers for a couple of TransposeC K=32 variants"

This reverts commit b3fd4d326d.

* feat: remove aquant config from grouped gemm quant example, update to add persistency as runtime parameter

* chore: removed work-around for aquant bug that has been fixed

* chore: fix typo in command-line parameters

* fix: correct K warp tile size for gfx950

* chore: fix incorrect warp tile configuration on gfx942
This commit is contained in:
Erwin Terpstra
2025-12-08 21:19:22 +01:00
committed by GitHub
parent ca6143f0b2
commit fe07b5a1bf
12 changed files with 948 additions and 206 deletions

View File

@@ -163,7 +163,6 @@ struct QuantGroupedGemmKernel
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
static_assert(UsePersistentKernel == true, "UsePersistentKernel must be true");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -262,10 +261,9 @@ struct QuantGroupedGemmKernel
auto karg =
QuantGroupedGemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
type_convert<CDataType*>(gemm_descs[i].e_ptr),
type_convert<const AQDataType*>(gemm_descs[i].aq_ptr),
type_convert<const BQDataType*>(gemm_descs[i].bq_ptr),
gemm_descs[i].k_batch,
type_convert<CDataType*>(gemm_descs[i].e_ptr),
M,
N,
K,
@@ -275,7 +273,8 @@ struct QuantGroupedGemmKernel
stride_b,
stride_e,
gemm_descs[i].stride_AQ,
gemm_descs[i].stride_BQ};
gemm_descs[i].stride_BQ,
gemm_descs[i].k_batch};
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
}
@@ -342,16 +341,32 @@ struct QuantGroupedGemmKernel
else
{
RunGemmWithPipelineSelection(a_ptr,
b_ptr,
aq_ptr,
bq_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
if constexpr(UsePersistentKernel)
{
RunGemmWithPipelineSelection(a_ptr,
b_ptr,
aq_ptr,
bq_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else // Non-persistent kernel
{
Base::RunGemm({a_ptr},
{b_ptr},
aq_ptr,
bq_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
}
@@ -451,7 +466,24 @@ struct QuantGroupedGemmKernel
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
if constexpr(kQuantType == QuantType::BQuantGrouped)
if constexpr(kQuantType == QuantType::AQuantGrouped)
{
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
b_block_window,
aq_block_window,
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0);
auto& c_block_window = gemm_tile_windows.at(Base::I4);
// Run Epilogue Pipeline
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
// Run GEMM pipeline
@@ -496,6 +528,53 @@ struct QuantGroupedGemmKernel
}
}
// Locate the group whose [block_start, block_end) workgroup-id range contains
// block_id, via binary search over the per-group descriptor array.
// Assumes groups are sorted by block_start with contiguous, non-overlapping
// ranges -- TODO confirm against the host-side code that fills block_start /
// block_end (see the kernel-arg construction loop).
CK_TILE_DEVICE index_t FindGroupId(const QuantGemmTransKernelArg* gemm_desc_ptr,
index_t block_id,
index_t group_count) const
{
index_t left = 0;
index_t right = group_count;
index_t group_id = index_t((left + right) >> 1);
// Iterate until block_id falls inside the candidate group's range (or the
// interval collapses). NOTE(review): bounds move to group_id itself rather
// than group_id +/- 1; termination relies on the midpoint making progress,
// which holds when the per-group ranges are contiguous as assumed above.
while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
block_id < gemm_desc_ptr[group_id].block_end)) &&
left <= right)
{
if(block_id < gemm_desc_ptr[group_id].block_start)
{
// Target precedes the candidate group.
right = group_id;
}
else
{
// Target is at or after the candidate group.
left = group_id;
}
group_id = index_t((left + right) >> 1);
}
return group_id;
}
// For non-persistent kernels: each workgroup processes exactly one output
// tile of one group. This overload is enabled only when the kernel was
// compiled with UsePersistentKernel == false.
//
// @param gemm_descs_const Constant-address-space array of per-group
//                         QuantGemmTransKernelArg descriptors.
// @param group_count      Number of groups in the descriptor array.
template <bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
index_t group_count) const
{
// Map this 1D block id back to its owning group, then to a 2D tile index
// within that group's M x N output grid.
const index_t block_id = ck_tile::get_block_1d_id();
const auto gemm_desc_ptr = reinterpret_cast<const QuantGemmTransKernelArg*>(
cast_pointer_to_generic_address_space(gemm_descs_const));
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
const auto& kargs = gemm_desc_ptr[group_id];
// Tiles per single pass over this group's output.
const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
0,
kargs.group_karg.M,
kargs.group_karg.N,
(block_id - kargs.block_start) % grid_size_2d);
// The quotient selects which repetition of the 2D grid this block belongs
// to -- presumably the split-k batch index; confirm against Run()'s
// corresponding parameter.
Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d);
}
// For persistent kernels
template <bool U = UsePersistentKernel,
typename = std::enable_if_t<U>,

View File

@@ -319,6 +319,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
if constexpr(HasHotLoop)
{
constexpr index_t tail_count =
((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) ? 1 : 2;
index_t i = 0;
do
{
@@ -366,7 +368,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 1));
} while(i < (num_loop - tail_count));
}
// tail
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
@@ -439,6 +441,51 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
num_loop,
p_smem);
}
/// @brief Runtime pipeline dispatch operator for grouped GEMM kernels.
///
/// This operator is used by grouped GEMM kernels where pipeline parameters
/// (has_hot_loop, num_loop, tail_number) are calculated on the device side
/// at runtime, not on the host side during compilation. This is necessary
/// because different GEMM problems in the group may have different K dimensions,
/// requiring different pipeline configurations that cannot be determined at
/// compile time.
///
/// @param a_dram_block_window_tmp Block window for A tensor in DRAM
/// @param b_dram_block_window_tmp Block window for B tensor in DRAM
/// @param aq_dram_block_window_tmp Block window for AQ (quantization scale) tensor in DRAM
/// @param num_loop Number of main loop iterations (calculated on device)
/// @param has_hot_loop Whether the pipeline has a hot loop (calculated on device)
/// @param tail_number Type of tail handling required (calculated on device)
/// @param p_smem Pointer to shared memory
/// @return Accumulated result tile in registers
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AQDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* p_smem,
index_t m = 0) const
{
// Adapter lambda: TailHandler lifts the runtime (has_hot_loop, tail_number)
// pair into compile-time values, which arrive here as integral-constant-like
// arguments and select the matching PipelineImpl instantiation.
const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) {
constexpr bool hot_loop = has_hot_loop_.value;
constexpr auto tail_num = tail_number_.value;
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
a_dram_block_window_tmp,
// Identity element-wise ops: A and B are consumed unmodified here.
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
aq_dram_block_window_tmp,
m, // dummy value, won't be used
num_loop,
p_smem);
};
// Dispatch on the runtime flags to the proper compile-time pipeline.
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
};
} // namespace ck_tile