[CK_Tile] Support for group size 128 for Preshuffle quant for 2d block scale gemm (#3462)

* formatted

* formatted

* formatting

* formatting

* formatting

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Split cpp file to reduce building time
- Support multiple GemmConfig

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update Readme

* enable prefill shapes

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Add support for rowcol and tensor GEMM operations

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update README

* adding preshuffle quant as new parameter and its associated new files

* remove debugging statements

* adding test

* enable preshuffle quant with permuteN

* updating readme and corresponding gemmconfigs

* updating cmake file

* fixing CI failures for grouped quant gemm

* debugging permuteN

* debugging

* debugging PermuteN

* initial commit

* resolving merge conflicts

* adding test cases

* initial commit with prints

* debugging

* fine-grained working

* debugging medium grained

* fixing the tile window

* formatting

* enabling prefill shapes

* working prefill shapes

* formatted

* clean up

* code cleanup

* bug fix after merging with develop

* G128 working for both prefill and decode shapes for preshufflequant

* clean up after merging with develop

* fixing group 64 for decode shapes

* non preshufflequant working for group size 128

* enable preshuffleb and preshufflequant with various group sizes

* reduce build time by splitting example into diff datatype files

* Adding tests for preshuffleQuant

* address review comment

* fix for gfx1201

* compile time fix for gfx1201

* clang formatted

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: Agarwal <khuagarw@ctr2-alola-login-03.amd.com>
This commit is contained in:
Khushbu Agarwal
2026-01-14 10:00:19 -08:00
committed by GitHub
parent 1fc5a3f3ac
commit 118afa455c
37 changed files with 1136 additions and 681 deletions

View File

@@ -319,7 +319,23 @@ struct BQuantBlockUniversalGemmAsBsCr
if constexpr(PreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
// constexpr index_t reg_offset = nIter;
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::QuantGroupSize::kN >
(NWarp * WarpGemm::kN))
{
if constexpr(Traits::NPerBlock ==
GemmTraits::QuantGroupSize::kN)
return kQScale;
else
return nIter; // for prefill needs kQscale, for decode needs
// nIter
}
else
{
return nIter;
}
}();
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;

View File

@@ -887,23 +887,27 @@ struct QuantGemmKernel
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
constexpr auto block_n =
TilePartitioner::NPerBlock /
QuantGroupSize::kN; // Number of N-dimension quantization groups per block
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(
I1); // Number of N-dimension elements per warp
constexpr auto warp_per_group =
(QuantGroupSize::kN <
warp_n) // Determine how many warps share the same scale in N-dimension
? (warp_n / QuantGroupSize::kN)
: (QuantGroupSize::kN / warp_n);
constexpr auto bqk_per_block =
TilePartitioner::KPerBlock /
QuantGroupSize::kK; // Number of K-dimension quantization groups per block
constexpr auto
tile_window_width = // The pre-shuffled layout flattens warp_n ×
// bqk_per_block scales per row, Padded up to warp_size
// to ensure coalesced memory access.
// Number of N-dimension quantization groups per block
constexpr auto block_n = (QuantGroupSize::kN <= TilePartitioner::NPerBlock)
? TilePartitioner::NPerBlock / QuantGroupSize::kN
: QuantGroupSize::kN / TilePartitioner::NPerBlock;
// Number of N-dimension elements per warp
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
// Determine how many warps share the same scale in N-dimension
constexpr auto warp_per_group = (QuantGroupSize::kN < warp_n)
? (warp_n / QuantGroupSize::kN)
: (QuantGroupSize::kN / warp_n);
// Number of K-dimension quantization groups per block
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
// The pre-shuffled layout flattens warp_n ×
// bqk_per_block scales per row, Padded up to warp_size
// to ensure coalesced memory access.
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
// Adapts based on fine vs coarse quantization granularity:
@@ -916,23 +920,42 @@ struct QuantGemmKernel
// height = block_n
constexpr auto tile_window_height =
(QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
auto block_n_idx =
i_n / TilePartitioner::NPerBlock; // Converts the global N-index (i_n) to a
// block index.
return make_tile_window(
bq_tensor_view,
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_n_idx * tile_window_height, 0});
auto block_n_idx = i_n / TilePartitioner::NPerBlock;
// For decode shapes GN: 128, Blocks needs to repeat 0,0,1,1,2,2 ...
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
{
block_n_idx = block_n_idx >> 1;
}
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
{
return make_tile_window(
bq_tensor_view,
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_n_idx, 0});
}
else
{
return make_tile_window(
bq_tensor_view,
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_n_idx * tile_window_height, 0});
}
}
else
{
constexpr auto tensor_dim =
(QuantGroupSize::kN <= TilePartitioner::NPerBlock)
? TilePartitioner::NPerBlock / QuantGroupSize::kN
: 1;
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(
bq_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
number<tensor_dim>{}),
{0, i_n / QuantGroupSize::kN});
}
else
@@ -940,7 +963,7 @@ struct QuantGemmKernel
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_tile_window(
bq_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
make_tuple(number<tensor_dim>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
}

View File

@@ -26,14 +26,15 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN;
static constexpr index_t NPerBlockBQ =
(QuantGroupSize::kN <= NPerBlock) ? NPerBlock / QuantGroupSize::kN : 1;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
// static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
static_assert(NPerBlock % QuantGroupSize::kN == 0,
"NPerBlock must be a multiple of QuantGroupSize::kN");
// static_assert(NPerBlock % QuantGroupSize::kN == 0,
// "NPerBlock must be a multiple of QuantGroupSize::kN");
static_assert(KPerBlock % QuantGroupSize::kK == 0,
"KPerBlock must be a multiple of QuantGroupSize::kK");

View File

@@ -45,7 +45,9 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
constexpr index_t NPerBlockBQ = (Problem::QuantGroupSize::kN <= NPerBlock)
? NPerBlock / Problem::QuantGroupSize::kN
: 1;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;

View File

@@ -66,7 +66,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
(QuantGroupSize::kN <= BlockGemmShape::kN)
? integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN)
: 1;
static constexpr index_t KPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);

View File

@@ -240,20 +240,26 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
//
// Example: NPerQ=8, WarpGemm::kN=16, KPerQ=128, BlockGemmShape::kK=256
// → 2 scales per warp in N, 2 K-groups per block
constexpr auto N1 = BlockGemmShape::kK /
KPerQ; // Number of K-dimension quantization groups per block,
// Each K-group of KPerQ elements shares the same scale.
constexpr auto N0 =
WarpGemm::kN / NPerQ; // Number of scales per warp in N-dimension, Since NPerQ
// <= WarpGemm::kN, each warp handles multiple scales.
constexpr auto N2 = 1; // Elements per thread
constexpr auto NR1 = NPerQ; // Elements sharing the same scale in N-dimension
// N1: Number of K-dimension quantization groups per block,
// Each K-group of KPerQ elements shares the same scale.
// N0: Number of scales per warp in N-dimension, Since NPerQ
// <= WarpGemm::kN, each warp handles multiple scales.
// N2: Elements per thread
// NR1: Elements sharing the same scale in N-dimension
// NR0: Interleave factor to ensure full warp utilization
// K1: Number of warps distributed along this dimension
// K0: Iterations per warp to cover the K-tile
// KR: No replication in K-dimension
constexpr auto N1 = BlockGemmShape::kK / KPerQ;
constexpr auto N0 = WarpGemm::kN / NPerQ;
constexpr auto N2 = 1;
constexpr auto NR1 = NPerQ;
constexpr auto NR0 =
warp_size /
(N0 * N1 * N2 * NR1); // Interleave factor to ensure full warp utilization
constexpr auto K1 = NWarps; // Number of warps distributed along this dimension
constexpr auto K0 = KPerTile / K1; // Iterations per warp to cover the K-tile
constexpr auto KR = 1; // No replication in K-dimension
(warp_size <= (N0 * N1 * N2 * NR1)) ? 1 : warp_size / (N0 * N1 * N2 * NR1);
constexpr auto K1 = NWarps;
constexpr auto K0 = KPerTile / K1;
constexpr auto KR = 1;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
@@ -275,15 +281,24 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
// Example: NPerQ=32, WarpGemm::kN=16, NWarps=4
// → KR=2 (2 warps share same scale), K1=2 (2 unique scale groups)
constexpr auto KR = NPerQ / WarpGemm::kN; // Number of warps sharing the same scale
constexpr auto K1 = NWarps / KR; // Number of distinct warp groups (unique scales)
constexpr auto K0 = KPerTile / K1; // Iterations to cover K-tile per warp group
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
constexpr auto N0 = 1; // Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN)
constexpr auto N2 = 1; // Elements per thread
constexpr auto NR1 = NPerQ; // Scale broadcast factor (full NPerQ)
// KR: Number of warps sharing the same scale
// K1: Number of distinct warp groups (unique scales)
// K0: Iterations to cover K-tile per warp group
// N1: K-dimension quantization groups
// N0: Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN)
// N2: Elements per thread
// NR1: Scale broadcast factor (full NPerQ)
// NR0: Remaining interleave factor
constexpr auto KR = NPerQ / WarpGemm::kN;
constexpr auto K1 = NWarps / KR;
constexpr auto K0 = KPerTile / K1;
constexpr auto N1 = BlockGemmShape::kK / KPerQ;
constexpr auto N0 = 1;
constexpr auto N2 = 1;
constexpr auto NR1 = NPerQ;
constexpr auto NR0 =
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor
(warp_size <= (N0 * N1 * N2 * NR1)) ? 1 : warp_size / (N0 * N1 * N2 * NR1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
@@ -303,12 +318,19 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
//
// Example: NPerQ=128, WarpGemm::kN=16, NWarps=4
// → 128 >= 16*4=64, so all 4 warps use the same scale
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
constexpr auto N0 = 1; // Minimal (1) since scale is shared across N
constexpr auto N2 = 1; // Elements per thread
constexpr auto NR1 = 32; // Fixed broadcast size
// N1: K-dimension quantization groups
// N0: Minimal (1) since scale is shared across N
// N2: Elements per thread
// NR1: Fixed broadcast size
// NR0: Remaining interleave factor
constexpr auto N1 = BlockGemmShape::kK / KPerQ;
constexpr auto N0 = 1;
constexpr auto N2 = 1;
constexpr auto NR1 = 32;
constexpr auto NR0 =
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor
(warp_size <= (N0 * N1 * N2 * NR1)) ? 1 : warp_size / (N0 * N1 * N2 * NR1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, NR0, NR1>,
tuple<sequence<KPerTile>, sequence<N0, N1, N2>>,

View File

@@ -79,10 +79,8 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
static constexpr auto TailNum = TailNum_;
static_assert(BlockGemmShape::kM % AQuantGroupSize::kM == 0);
static_assert(BlockGemmShape::kN % AQuantGroupSize::kN == 0);
static_assert(BlockGemmShape::kK % AQuantGroupSize::kK == 0);
static_assert(BlockGemmShape::kM % BQuantGroupSize::kM == 0);
static_assert(BlockGemmShape::kN % BQuantGroupSize::kN == 0);
static_assert(BlockGemmShape::kK % BQuantGroupSize::kK == 0);
[[nodiscard]] CK_TILE_HOST static const std::string GetName()

View File

@@ -144,23 +144,32 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
// Insert LDS read/write groups periodically based on ds_rep.
// The % pattern staggers READ and WRITE so they don't collapse
// into the same cycle in the model.
if constexpr(ds_rep > 0 && i_inst % ds_rep == 0)
if constexpr(ds_rep > 0)
{
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::DS_READ, 1, 0); // DS read
}
if constexpr(ds_rep > 0 && i_inst % ds_rep == 1)
{
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write
}
if constexpr(buffer_load_rep > 0 && i_inst % buffer_load_rep == 0)
{
if constexpr(ds_write_inst > 0)
if(i_inst % ds_rep == 0)
{
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read
LLVMSchedGroupMask::DS_READ, 1, 0); // DS read
}
}
if constexpr(ds_rep > 0)
{
if(i_inst % ds_rep == 1)
{
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write
}
}
if constexpr(buffer_load_rep > 0)
{
if(i_inst % buffer_load_rep == 0)
{
if constexpr(ds_write_inst > 0)
{
__builtin_amdgcn_sched_group_barrier(
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read
}
}
}
// Always mark some VALU work in the loop to reflect auxiliary scalar
@@ -354,7 +363,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
if constexpr(PreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
@@ -431,7 +440,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
if constexpr(PreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
@@ -468,7 +477,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
if constexpr(PreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),