[CK_TILE] Grouped gemm quant tensor layouts (#3414)

* feat: add RRR, CRR, CCR layouts for a/b quant grouped gemm tests and examples. Refactor example setup to improve compile time

* chore: split out bquant preshuffle test, and reduce tile size to 128 to temporarily solve slow compile times

* chore: set m/n warp tile to 16 as configurations with 32 seem to have some support problems

* fix: missing check for transposed load in bquant pipeline

* chore: lower unit test tensors dimensions a bit for faster tests

* chore: set grouped gemm example M/N warp tile to 16

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Erwin Terpstra
2025-12-25 08:01:23 +01:00
committed by GitHub
parent 14668a56e3
commit e08efa551f
20 changed files with 662 additions and 490 deletions

View File

@@ -422,7 +422,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
currIdx = (currIdx + 1) % 2;
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -433,7 +433,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
// Note: BDataType gets converted during loading from PkInt4
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(