[CK_Tile] Support for various group sizes Preshuffle quant for 2d block scale gemm (#3445)

* formatted

* formatted

* formatting

* formatting

* formatting

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Split cpp file to reduce building time
- Support multiple GemmConfig

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update Readme

* enable prefill shapes

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Add support for rowcol and tensor GEMM operations

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update README

* adding preshuffle quant as new parameter and its associated new files

* remove debugging statements

* adding test

* enable preshuffle quant with permuteN

* updating readme and corresponding gemmconfigs

* updating cmake file

* fixing CI failures for grouped quant gemm

* debugging permuteN

* debugging

* debugging PermuteN

* initial commit

* resolving merge conflicts

* adding test cases

* initial commit with prints

* debugging

* fine-grained working

* debugging medium grained

* fixing the tile window

* formatting

* enabling prefill shapes

* working prefill shapes

* formatted

* clean up

* code cleanup

* bug fix after merging with develop

* clean up after merging with develop

* added comments for the tile window and tile distribution encoding

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: Agarwal <khuagarw@ctr2-alola-login-03.amd.com>
This commit is contained in:
Khushbu Agarwal
2026-01-06 12:46:59 -08:00
committed by GitHub
parent 76696ace44
commit aaa35f0bbf
8 changed files with 428 additions and 669 deletions

View File

@@ -48,7 +48,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
@@ -68,7 +67,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
BlockSize,
NPerBlock / WarpGemm::kN,
ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()),
VecLoadSize,
Problem::BQuantGroupSize::kN,
Problem::BQuantGroupSize::kK,
BQLayout,
PreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution();
@@ -83,6 +83,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
KPerBlockBQ, // Logical K dimension
NPerBlockBQ, // Logical N dimension
Problem::BQuantGroupSize::kN,
Problem::BQuantGroupSize::kK,
BQLayout>;
return TileEncodingPattern::make_2d_static_tile_distribution();

View File

@@ -65,8 +65,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
static constexpr index_t NPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
static constexpr index_t KPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
@@ -300,9 +302,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
const BQDramTileWindowStep bq_dram_tile_window_step =
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
0)
(PreshuffleQuant)
? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0)
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
: make_array(0, KPerBlockBQ);

View File

@@ -192,6 +192,7 @@ template <typename BlockGemmShape,
index_t KPerTile,
index_t NPerTile,
index_t NPerQ,
index_t KPerQ,
typename BQLayout = tensor_layout::gemm::ColumnMajor,
bool PreshuffleQuant = false>
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
@@ -208,31 +209,6 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
static_assert(num_warps == MWarps * NWarps * KWarps);
static_assert(KWarps == 1);
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
///
/// This function determines the optimal thread distribution pattern for loading and applying
/// quantization scales to the B matrix based on the quantization group size (NPerQ) relative
/// to warp dimensions.
///
/// Three distinct distribution patterns are handled:
///
/// 1. Fine-grained quantization (NPerQ < WarpGemm::kN):
/// - Multiple quantization groups exist within a single warp's N-dimension
/// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp)
/// - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast
/// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
///
/// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps):
/// - Each warp handles exactly one quantization scale
/// - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN
/// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
///
/// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps):
/// - Quantization group spans multiple warps
/// - All warps share the same scale value
/// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
///
/// @return A static tile distribution encoding for the BQ scale tensor
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
// Preshuffle only supported for ColumnMajor currently
@@ -241,22 +217,136 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
if constexpr(PreshuffleQuant)
{
// ColumnMajor only for preshuffle
constexpr index_t X1 = warp_size;
constexpr index_t X0 = NPerTile / warp_size;
constexpr index_t Y1 = NWarps;
constexpr index_t Y0 = KPerTile / Y1;
// =============================================================================
// PRE-SHUFFLED BQ SCALE TILE DISTRIBUTION
// =============================================================================
// For pre-shuffled quantization, the BQ scale tensor has been reorganized
// (pre-shuffled) to optimize memory access patterns during dequantization.
//
// Tile Dimensions:
// - K-axis (Y in encoding): Corresponds to the K-dimension iteration
// - N-axis (X in encoding): Flattened scale index combining N and K groups
//
// The encoding distributes work across threads such that each thread loads
// the correct pre-shuffled scale for its corresponding B-matrix elements.
// =============================================================================
if constexpr(NPerQ <= WarpGemm::kN)
{
// =========================================================================
// CASE 1: Fine-grained Quantization (NPerQ <= WarpGemm::kN)
// =========================================================================
// Multiple quantization scales exist within a single warp's N-dimension.
// Each warp processes multiple scales: WarpGemm::kN / NPerQ scales per warp.
//
// Example: NPerQ=8, WarpGemm::kN=16, KPerQ=128, BlockGemmShape::kK=256
// → 2 scales per warp in N, 2 K-groups per block
constexpr auto N1 = BlockGemmShape::kK /
KPerQ; // Number of K-dimension quantization groups per block,
// Each K-group of KPerQ elements shares the same scale.
constexpr auto N0 =
WarpGemm::kN / NPerQ; // Number of scales per warp in N-dimension, Since NPerQ
// <= WarpGemm::kN, each warp handles multiple scales.
constexpr auto N2 = 1; // Elements per thread
constexpr auto NR1 = NPerQ; // Elements sharing the same scale in N-dimension
constexpr auto NR0 =
warp_size /
(N0 * N1 * N2 * NR1); // Interleave factor to ensure full warp utilization
constexpr auto K1 = NWarps; // Number of warps distributed along this dimension
constexpr auto K0 = KPerTile / K1; // Iterations per warp to cover the K-tile
constexpr auto KR = 1; // No replication in K-dimension
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps>,
tuple<sequence<Y0, Y1>, sequence<X0, X1>>,
tuple<sequence<0, 1>, sequence<2>>,
tuple<sequence<0, 1>, sequence<1>>,
sequence<1, 2>,
sequence<0, 0>>{});
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
tuple<sequence<0, 1>, sequence<0, 2, 0, 2, 0>>,
tuple<sequence<0, 1>, sequence<1, 0, 2, 1, 3>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else if constexpr(NPerQ < WarpGemm::kN * NWarps)
{
// =========================================================================
// CASE 2: Medium-grained Quantization (WarpGemm::kN < NPerQ < WarpGemm::kN *
// NWarps)
// =========================================================================
// Each warp handles exactly one quantization scale in N-dimension.
// Some warps share the same scale (KR > 1 creates warp grouping).
//
// Example: NPerQ=32, WarpGemm::kN=16, NWarps=4
// → KR=2 (2 warps share same scale), K1=2 (2 unique scale groups)
constexpr auto KR = NPerQ / WarpGemm::kN; // Number of warps sharing the same scale
constexpr auto K1 = NWarps / KR; // Number of distinct warp groups (unique scales)
constexpr auto K0 = KPerTile / K1; // Iterations to cover K-tile per warp group
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
constexpr auto N0 = 1; // Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN)
constexpr auto N2 = 1; // Elements per thread
constexpr auto NR1 = NPerQ; // Scale broadcast factor (full NPerQ)
constexpr auto NR0 =
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
tuple<sequence<0, 1, 0>, sequence<0, 2, 0, 2>>,
tuple<sequence<0, 1, 3>, sequence<1, 0, 2, 1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else
{
// =========================================================================
// CASE 3: Coarse-grained Quantization (NPerQ >= WarpGemm::kN * NWarps)
// =========================================================================
// The quantization group spans ALL warps in N-dimension.
// All warps share the same scale value for their N-tiles.
//
// Example: NPerQ=128, WarpGemm::kN=16, NWarps=4
// → 128 >= 16*4=64, so all 4 warps use the same scale
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
constexpr auto N0 = 1; // Minimal (1) since scale is shared across N
constexpr auto N2 = 1; // Elements per thread
constexpr auto NR1 = 32; // Fixed broadcast size
constexpr auto NR0 =
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, NR0, NR1>,
tuple<sequence<KPerTile>, sequence<N0, N1, N2>>,
tuple<sequence<0, 0>, sequence<0, 2, 0, 2>>,
tuple<sequence<0, 1>, sequence<2, 0, 3, 1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
}
else
{
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
///
/// This function determines the optimal thread distribution pattern for loading and
/// applying quantization scales to the B matrix based on the quantization group size
/// (NPerQ) relative to warp dimensions.
///
/// Three distinct distribution patterns are handled:
///
/// 1. Fine-grained quantization (NPerQ < WarpGemm::kN):
/// - Multiple quantization groups exist within a single warp's N-dimension
/// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp)
/// - Distribution includes explicit replication factor (XR = NPerQ) for scale
/// broadcast
/// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
///
/// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps):
/// - Each warp handles exactly one quantization scale
/// - Scales are distributed across warps with replication factor XR = NPerQ /
/// WarpGemm::kN
/// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
///
/// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps):
/// - Quantization group spans multiple warps
/// - All warps share the same scale value
/// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
///
/// @return A static tile distribution encoding for the BQ scale tensor
if constexpr(NPerQ < WarpGemm::kN)
{
// Case 1: Fine-grained - multiple quantization scales within a single warp

View File

@@ -71,6 +71,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
static constexpr index_t NPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
static constexpr index_t KPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
static constexpr index_t QScalesPerBlockRow =
@@ -352,8 +354,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
if constexpr(PreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0});
}
else
@@ -427,8 +431,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
if constexpr(PreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0});
}
else
@@ -462,8 +468,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
if constexpr(PreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0});
}
else