mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[CK_Tile] Support for various group sizes Preshuffle quant for 2d block scale gemm (#3445)
* formatted * formatted * formatting * formatting * formatting * [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * enable prefill shapes * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * adding preshuffle quant as new parameter and its associated new files * remove debugging statements * adding test * enable preshuffle quant with permuteN * updating readme and correcponding gemmconfigs * updating cmake file * fixing CI failures for grouped quant gemm * debugging permuteN * debugging * debugging PermuteN * initial commit * resolving merge conflicts * adding test cases * initial commit with prints * debugging * fine-grained working * debugging medium grained * fixing the tile window * formatting * enabling prefill shapes * working prefill shapes * formatted * clean up * code cleanup * bug fix after merging with develop * clean up after merging with develop * added comments for the tile window and tile distribution encoding --------- Co-authored-by: Cong Ma <congma13@amd.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com> Co-authored-by: Agarwal <khuagarw@ctr2-alola-login-03.amd.com>
This commit is contained in:
@@ -48,7 +48,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
|
||||
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
@@ -68,7 +67,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
BlockSize,
|
||||
NPerBlock / WarpGemm::kN,
|
||||
ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()),
|
||||
VecLoadSize,
|
||||
Problem::BQuantGroupSize::kN,
|
||||
Problem::BQuantGroupSize::kK,
|
||||
BQLayout,
|
||||
PreshuffleQuant>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
@@ -83,6 +83,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
KPerBlockBQ, // Logical K dimension
|
||||
NPerBlockBQ, // Logical N dimension
|
||||
Problem::BQuantGroupSize::kN,
|
||||
Problem::BQuantGroupSize::kK,
|
||||
BQLayout>;
|
||||
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
|
||||
@@ -65,8 +65,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
|
||||
static constexpr index_t KPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -300,9 +302,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
const BQDramTileWindowStep bq_dram_tile_window_step =
|
||||
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{}),
|
||||
0)
|
||||
(PreshuffleQuant)
|
||||
? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, NPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0)
|
||||
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
|
||||
: make_array(0, KPerBlockBQ);
|
||||
|
||||
|
||||
@@ -192,6 +192,7 @@ template <typename BlockGemmShape,
|
||||
index_t KPerTile,
|
||||
index_t NPerTile,
|
||||
index_t NPerQ,
|
||||
index_t KPerQ,
|
||||
typename BQLayout = tensor_layout::gemm::ColumnMajor,
|
||||
bool PreshuffleQuant = false>
|
||||
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
|
||||
@@ -208,31 +209,6 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
static_assert(num_warps == MWarps * NWarps * KWarps);
|
||||
static_assert(KWarps == 1);
|
||||
|
||||
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
|
||||
///
|
||||
/// This function determines the optimal thread distribution pattern for loading and applying
|
||||
/// quantization scales to the B matrix based on the quantization group size (NPerQ) relative
|
||||
/// to warp dimensions.
|
||||
///
|
||||
/// Three distinct distribution patterns are handled:
|
||||
///
|
||||
/// 1. Fine-grained quantization (NPerQ < WarpGemm::kN):
|
||||
/// - Multiple quantization groups exist within a single warp's N-dimension
|
||||
/// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp)
|
||||
/// - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast
|
||||
/// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
|
||||
///
|
||||
/// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps):
|
||||
/// - Each warp handles exactly one quantization scale
|
||||
/// - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN
|
||||
/// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
|
||||
///
|
||||
/// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps):
|
||||
/// - Quantization group spans multiple warps
|
||||
/// - All warps share the same scale value
|
||||
/// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
|
||||
///
|
||||
/// @return A static tile distribution encoding for the BQ scale tensor
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
// Preshuffle only supported for ColumnMajor currently
|
||||
@@ -241,22 +217,136 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
// ColumnMajor only for preshuffle
|
||||
constexpr index_t X1 = warp_size;
|
||||
constexpr index_t X0 = NPerTile / warp_size;
|
||||
constexpr index_t Y1 = NWarps;
|
||||
constexpr index_t Y0 = KPerTile / Y1;
|
||||
// =============================================================================
|
||||
// PRE-SHUFFLED BQ SCALE TILE DISTRIBUTION
|
||||
// =============================================================================
|
||||
// For pre-shuffled quantization, the BQ scale tensor has been reorganized
|
||||
// (pre-shuffled) to optimize memory access patterns during dequantization.
|
||||
//
|
||||
// Tile Dimensions:
|
||||
// - K-axis (Y in encoding): Corresponds to the K-dimension iteration
|
||||
// - N-axis (X in encoding): Flattened scale index combining N and K groups
|
||||
//
|
||||
// The encoding distributes work across threads such that each thread loads
|
||||
// the correct pre-shuffled scale for its corresponding B-matrix elements.
|
||||
// =============================================================================
|
||||
if constexpr(NPerQ <= WarpGemm::kN)
|
||||
{
|
||||
// =========================================================================
|
||||
// CASE 1: Fine-grained Quantization (NPerQ <= WarpGemm::kN)
|
||||
// =========================================================================
|
||||
// Multiple quantization scales exist within a single warp's N-dimension.
|
||||
// Each warp processes multiple scales: WarpGemm::kN / NPerQ scales per warp.
|
||||
//
|
||||
// Example: NPerQ=8, WarpGemm::kN=16, KPerQ=128, BlockGemmShape::kK=256
|
||||
// → 2 scales per warp in N, 2 K-groups per block
|
||||
constexpr auto N1 = BlockGemmShape::kK /
|
||||
KPerQ; // Number of K-dimension quantization groups per block,
|
||||
// Each K-group of KPerQ elements shares the same scale.
|
||||
constexpr auto N0 =
|
||||
WarpGemm::kN / NPerQ; // Number of scales per warp in N-dimension, Since NPerQ
|
||||
// <= WarpGemm::kN, each warp handles multiple scales.
|
||||
constexpr auto N2 = 1; // Elements per thread
|
||||
constexpr auto NR1 = NPerQ; // Elements sharing the same scale in N-dimension
|
||||
constexpr auto NR0 =
|
||||
warp_size /
|
||||
(N0 * N1 * N2 * NR1); // Interleave factor to ensure full warp utilization
|
||||
constexpr auto K1 = NWarps; // Number of warps distributed along this dimension
|
||||
constexpr auto K0 = KPerTile / K1; // Iterations per warp to cover the K-tile
|
||||
constexpr auto KR = 1; // No replication in K-dimension
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps>,
|
||||
tuple<sequence<Y0, Y1>, sequence<X0, X1>>,
|
||||
tuple<sequence<0, 1>, sequence<2>>,
|
||||
tuple<sequence<0, 1>, sequence<1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
|
||||
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2, 0, 2, 0>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 0, 2, 1, 3>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
else if constexpr(NPerQ < WarpGemm::kN * NWarps)
|
||||
{
|
||||
// =========================================================================
|
||||
// CASE 2: Medium-grained Quantization (WarpGemm::kN < NPerQ < WarpGemm::kN *
|
||||
// NWarps)
|
||||
// =========================================================================
|
||||
// Each warp handles exactly one quantization scale in N-dimension.
|
||||
// Some warps share the same scale (KR > 1 creates warp grouping).
|
||||
//
|
||||
// Example: NPerQ=32, WarpGemm::kN=16, NWarps=4
|
||||
// → KR=2 (2 warps share same scale), K1=2 (2 unique scale groups)
|
||||
|
||||
constexpr auto KR = NPerQ / WarpGemm::kN; // Number of warps sharing the same scale
|
||||
constexpr auto K1 = NWarps / KR; // Number of distinct warp groups (unique scales)
|
||||
constexpr auto K0 = KPerTile / K1; // Iterations to cover K-tile per warp group
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
|
||||
constexpr auto N0 = 1; // Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN)
|
||||
constexpr auto N2 = 1; // Elements per thread
|
||||
constexpr auto NR1 = NPerQ; // Scale broadcast factor (full NPerQ)
|
||||
constexpr auto NR0 =
|
||||
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
|
||||
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
|
||||
tuple<sequence<0, 1, 0>, sequence<0, 2, 0, 2>>,
|
||||
tuple<sequence<0, 1, 3>, sequence<1, 0, 2, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// =========================================================================
|
||||
// CASE 3: Coarse-grained Quantization (NPerQ >= WarpGemm::kN * NWarps)
|
||||
// =========================================================================
|
||||
// The quantization group spans ALL warps in N-dimension.
|
||||
// All warps share the same scale value for their N-tiles.
|
||||
//
|
||||
// Example: NPerQ=128, WarpGemm::kN=16, NWarps=4
|
||||
// → 128 >= 16*4=64, so all 4 warps use the same scale
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
|
||||
constexpr auto N0 = 1; // Minimal (1) since scale is shared across N
|
||||
constexpr auto N2 = 1; // Elements per thread
|
||||
constexpr auto NR1 = 32; // Fixed broadcast size
|
||||
constexpr auto NR0 =
|
||||
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NWarps, NR0, NR1>,
|
||||
tuple<sequence<KPerTile>, sequence<N0, N1, N2>>,
|
||||
tuple<sequence<0, 0>, sequence<0, 2, 0, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<2, 0, 3, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
|
||||
///
|
||||
/// This function determines the optimal thread distribution pattern for loading and
|
||||
/// applying quantization scales to the B matrix based on the quantization group size
|
||||
/// (NPerQ) relative to warp dimensions.
|
||||
///
|
||||
/// Three distinct distribution patterns are handled:
|
||||
///
|
||||
/// 1. Fine-grained quantization (NPerQ < WarpGemm::kN):
|
||||
/// - Multiple quantization groups exist within a single warp's N-dimension
|
||||
/// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp)
|
||||
/// - Distribution includes explicit replication factor (XR = NPerQ) for scale
|
||||
/// broadcast
|
||||
/// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
|
||||
///
|
||||
/// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps):
|
||||
/// - Each warp handles exactly one quantization scale
|
||||
/// - Scales are distributed across warps with replication factor XR = NPerQ /
|
||||
/// WarpGemm::kN
|
||||
/// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
|
||||
///
|
||||
/// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps):
|
||||
/// - Quantization group spans multiple warps
|
||||
/// - All warps share the same scale value
|
||||
/// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
|
||||
///
|
||||
/// @return A static tile distribution encoding for the BQ scale tensor
|
||||
if constexpr(NPerQ < WarpGemm::kN)
|
||||
{
|
||||
// Case 1: Fine-grained - multiple quantization scales within a single warp
|
||||
|
||||
@@ -71,6 +71,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
|
||||
static constexpr index_t KPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
@@ -352,8 +354,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{}),
|
||||
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0});
|
||||
}
|
||||
else
|
||||
@@ -427,8 +431,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{}),
|
||||
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0});
|
||||
}
|
||||
else
|
||||
@@ -462,8 +468,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{}),
|
||||
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0});
|
||||
}
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user