mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_Tile] Support for group size 128 for Preshuffle quant for 2d block scale gemm (#3462)
* formatted * formatted * formatting * formatting * formatting * [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * enable prefill shapes * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * adding preshuffle quant as new parameter and its associated new files * remove debugging statements * adding test * enable preshuffle quant with permuteN * updating readme and corresponding gemmconfigs * updating cmake file * fixing CI failures for grouped quant gemm * debugging permuteN * debugging * debugging PermuteN * initial commit * resolving merge conflicts * adding test cases * initial commit with prints * debugging * fine-grained working * debugging medium grained * fixing the tile window * formatting * enabling prefill shapes * working prefill shapes * formatted * clean up * code cleanup * bug fix after merging with develop * G128 working for both prefill and decode shapes for preshufflequant * clean up after merging with develop * fixing group 64 for decode shapes * non preshufflequant working for group size 128 * enable preshuffleb and preshufflequant with various group sizes * reduce build time by splitting example into diff datatype files * Adding tests for preshuffleQuant * address review comment * fix for gfx1201 * compile time fix for gfx1201 * clang formatted --------- Co-authored-by: Cong Ma <congma13@amd.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com> Co-authored-by: Agarwal <khuagarw@ctr2-alola-login-03.amd.com>
This commit is contained in:
@@ -319,7 +319,23 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
constexpr index_t reg_offset = nIter;
|
||||
// constexpr index_t reg_offset = nIter;
|
||||
constexpr index_t reg_offset = [&]() {
|
||||
if constexpr(GemmTraits::QuantGroupSize::kN >
|
||||
(NWarp * WarpGemm::kN))
|
||||
{
|
||||
if constexpr(Traits::NPerBlock ==
|
||||
GemmTraits::QuantGroupSize::kN)
|
||||
return kQScale;
|
||||
else
|
||||
return nIter; // for prefill needs kQscale, for decode needs
|
||||
// nIter
|
||||
}
|
||||
else
|
||||
{
|
||||
return nIter;
|
||||
}
|
||||
}();
|
||||
auto pull_from_lane =
|
||||
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
|
||||
|
||||
|
||||
@@ -887,23 +887,27 @@ struct QuantGemmKernel
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr auto block_n =
|
||||
TilePartitioner::NPerBlock /
|
||||
QuantGroupSize::kN; // Number of N-dimension quantization groups per block
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(
|
||||
I1); // Number of N-dimension elements per warp
|
||||
constexpr auto warp_per_group =
|
||||
(QuantGroupSize::kN <
|
||||
warp_n) // Determine how many warps share the same scale in N-dimension
|
||||
? (warp_n / QuantGroupSize::kN)
|
||||
: (QuantGroupSize::kN / warp_n);
|
||||
constexpr auto bqk_per_block =
|
||||
TilePartitioner::KPerBlock /
|
||||
QuantGroupSize::kK; // Number of K-dimension quantization groups per block
|
||||
constexpr auto
|
||||
tile_window_width = // The pre-shuffled layout flattens warp_n ×
|
||||
// bqk_per_block scales per row, Padded up to warp_size
|
||||
// to ensure coalesced memory access.
|
||||
|
||||
// Number of N-dimension quantization groups per block
|
||||
constexpr auto block_n = (QuantGroupSize::kN <= TilePartitioner::NPerBlock)
|
||||
? TilePartitioner::NPerBlock / QuantGroupSize::kN
|
||||
: QuantGroupSize::kN / TilePartitioner::NPerBlock;
|
||||
|
||||
// Number of N-dimension elements per warp
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
// Determine how many warps share the same scale in N-dimension
|
||||
constexpr auto warp_per_group = (QuantGroupSize::kN < warp_n)
|
||||
? (warp_n / QuantGroupSize::kN)
|
||||
: (QuantGroupSize::kN / warp_n);
|
||||
|
||||
// Number of K-dimension quantization groups per block
|
||||
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
// The pre-shuffled layout flattens warp_n ×
|
||||
// bqk_per_block scales per row, Padded up to warp_size
|
||||
// to ensure coalesced memory access.
|
||||
constexpr auto tile_window_width =
|
||||
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
|
||||
|
||||
// Adapts based on fine vs coarse quantization granularity:
|
||||
@@ -916,23 +920,42 @@ struct QuantGemmKernel
|
||||
// height = block_n
|
||||
constexpr auto tile_window_height =
|
||||
(QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
|
||||
auto block_n_idx =
|
||||
i_n / TilePartitioner::NPerBlock; // Converts the global N-index (i_n) to a
|
||||
// block index.
|
||||
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_n_idx * tile_window_height, 0});
|
||||
auto block_n_idx = i_n / TilePartitioner::NPerBlock;
|
||||
|
||||
// For decode shapes GN: 128, Blocks needs to repeat 0,0,1,1,2,2 ...
|
||||
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
|
||||
{
|
||||
block_n_idx = block_n_idx >> 1;
|
||||
}
|
||||
|
||||
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
|
||||
{
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_n_idx, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_n_idx * tile_window_height, 0});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto tensor_dim =
|
||||
(QuantGroupSize::kN <= TilePartitioner::NPerBlock)
|
||||
? TilePartitioner::NPerBlock / QuantGroupSize::kN
|
||||
: 1;
|
||||
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
|
||||
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
|
||||
number<tensor_dim>{}),
|
||||
{0, i_n / QuantGroupSize::kN});
|
||||
}
|
||||
else
|
||||
@@ -940,7 +963,7 @@ struct QuantGemmKernel
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
|
||||
make_tuple(number<tensor_dim>{},
|
||||
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
|
||||
{i_n / QuantGroupSize::kN, 0});
|
||||
}
|
||||
|
||||
@@ -26,14 +26,15 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN;
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
(QuantGroupSize::kN <= NPerBlock) ? NPerBlock / QuantGroupSize::kN : 1;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
|
||||
// static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
|
||||
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
|
||||
|
||||
static_assert(NPerBlock % QuantGroupSize::kN == 0,
|
||||
"NPerBlock must be a multiple of QuantGroupSize::kN");
|
||||
// static_assert(NPerBlock % QuantGroupSize::kN == 0,
|
||||
// "NPerBlock must be a multiple of QuantGroupSize::kN");
|
||||
static_assert(KPerBlock % QuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize::kK");
|
||||
|
||||
|
||||
@@ -45,7 +45,9 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
|
||||
constexpr index_t NPerBlockBQ = (Problem::QuantGroupSize::kN <= NPerBlock)
|
||||
? NPerBlock / Problem::QuantGroupSize::kN
|
||||
: 1;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
|
||||
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
|
||||
@@ -66,7 +66,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
|
||||
(QuantGroupSize::kN <= BlockGemmShape::kN)
|
||||
? integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN)
|
||||
: 1;
|
||||
static constexpr index_t KPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
|
||||
|
||||
|
||||
@@ -240,20 +240,26 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
//
|
||||
// Example: NPerQ=8, WarpGemm::kN=16, KPerQ=128, BlockGemmShape::kK=256
|
||||
// → 2 scales per warp in N, 2 K-groups per block
|
||||
constexpr auto N1 = BlockGemmShape::kK /
|
||||
KPerQ; // Number of K-dimension quantization groups per block,
|
||||
// Each K-group of KPerQ elements shares the same scale.
|
||||
constexpr auto N0 =
|
||||
WarpGemm::kN / NPerQ; // Number of scales per warp in N-dimension, Since NPerQ
|
||||
// <= WarpGemm::kN, each warp handles multiple scales.
|
||||
constexpr auto N2 = 1; // Elements per thread
|
||||
constexpr auto NR1 = NPerQ; // Elements sharing the same scale in N-dimension
|
||||
|
||||
// N1: Number of K-dimension quantization groups per block,
|
||||
// Each K-group of KPerQ elements shares the same scale.
|
||||
// N0: Number of scales per warp in N-dimension, Since NPerQ
|
||||
// <= WarpGemm::kN, each warp handles multiple scales.
|
||||
// N2: Elements per thread
|
||||
// NR1: Elements sharing the same scale in N-dimension
|
||||
// NR0: Interleave factor to ensure full warp utilization
|
||||
// K1: Number of warps distributed along this dimension
|
||||
// K0: Iterations per warp to cover the K-tile
|
||||
// KR: No replication in K-dimension
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ;
|
||||
constexpr auto N0 = WarpGemm::kN / NPerQ;
|
||||
constexpr auto N2 = 1;
|
||||
constexpr auto NR1 = NPerQ;
|
||||
constexpr auto NR0 =
|
||||
warp_size /
|
||||
(N0 * N1 * N2 * NR1); // Interleave factor to ensure full warp utilization
|
||||
constexpr auto K1 = NWarps; // Number of warps distributed along this dimension
|
||||
constexpr auto K0 = KPerTile / K1; // Iterations per warp to cover the K-tile
|
||||
constexpr auto KR = 1; // No replication in K-dimension
|
||||
(warp_size <= (N0 * N1 * N2 * NR1)) ? 1 : warp_size / (N0 * N1 * N2 * NR1);
|
||||
constexpr auto K1 = NWarps;
|
||||
constexpr auto K0 = KPerTile / K1;
|
||||
constexpr auto KR = 1;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
|
||||
@@ -275,15 +281,24 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
// Example: NPerQ=32, WarpGemm::kN=16, NWarps=4
|
||||
// → KR=2 (2 warps share same scale), K1=2 (2 unique scale groups)
|
||||
|
||||
constexpr auto KR = NPerQ / WarpGemm::kN; // Number of warps sharing the same scale
|
||||
constexpr auto K1 = NWarps / KR; // Number of distinct warp groups (unique scales)
|
||||
constexpr auto K0 = KPerTile / K1; // Iterations to cover K-tile per warp group
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
|
||||
constexpr auto N0 = 1; // Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN)
|
||||
constexpr auto N2 = 1; // Elements per thread
|
||||
constexpr auto NR1 = NPerQ; // Scale broadcast factor (full NPerQ)
|
||||
// KR: Number of warps sharing the same scale
|
||||
// K1: Number of distinct warp groups (unique scales)
|
||||
// K0: Iterations to cover K-tile per warp group
|
||||
// N1: K-dimension quantization groups
|
||||
// N0: Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN)
|
||||
// N2: Elements per thread
|
||||
// NR1: Scale broadcast factor (full NPerQ)
|
||||
// NR0: Remaining interleave factor
|
||||
|
||||
constexpr auto KR = NPerQ / WarpGemm::kN;
|
||||
constexpr auto K1 = NWarps / KR;
|
||||
constexpr auto K0 = KPerTile / K1;
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ;
|
||||
constexpr auto N0 = 1;
|
||||
constexpr auto N2 = 1;
|
||||
constexpr auto NR1 = NPerQ;
|
||||
constexpr auto NR0 =
|
||||
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor
|
||||
(warp_size <= (N0 * N1 * N2 * NR1)) ? 1 : warp_size / (N0 * N1 * N2 * NR1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
|
||||
@@ -303,12 +318,19 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
//
|
||||
// Example: NPerQ=128, WarpGemm::kN=16, NWarps=4
|
||||
// → 128 >= 16*4=64, so all 4 warps use the same scale
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
|
||||
constexpr auto N0 = 1; // Minimal (1) since scale is shared across N
|
||||
constexpr auto N2 = 1; // Elements per thread
|
||||
constexpr auto NR1 = 32; // Fixed broadcast size
|
||||
|
||||
// N1: K-dimension quantization groups
|
||||
// N0: Minimal (1) since scale is shared across N
|
||||
// N2: Elements per thread
|
||||
// NR1: Fixed broadcast size
|
||||
// NR0: Remaining interleave factor
|
||||
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ;
|
||||
constexpr auto N0 = 1;
|
||||
constexpr auto N2 = 1;
|
||||
constexpr auto NR1 = 32;
|
||||
constexpr auto NR0 =
|
||||
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor
|
||||
(warp_size <= (N0 * N1 * N2 * NR1)) ? 1 : warp_size / (N0 * N1 * N2 * NR1);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NWarps, NR0, NR1>,
|
||||
tuple<sequence<KPerTile>, sequence<N0, N1, N2>>,
|
||||
|
||||
@@ -79,10 +79,8 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
|
||||
static_assert(BlockGemmShape::kM % AQuantGroupSize::kM == 0);
|
||||
static_assert(BlockGemmShape::kN % AQuantGroupSize::kN == 0);
|
||||
static_assert(BlockGemmShape::kK % AQuantGroupSize::kK == 0);
|
||||
static_assert(BlockGemmShape::kM % BQuantGroupSize::kM == 0);
|
||||
static_assert(BlockGemmShape::kN % BQuantGroupSize::kN == 0);
|
||||
static_assert(BlockGemmShape::kK % BQuantGroupSize::kK == 0);
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
|
||||
@@ -144,23 +144,32 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
// Insert LDS read/write groups periodically based on ds_rep.
|
||||
// The % pattern staggers READ and WRITE so they don't collapse
|
||||
// into the same cycle in the model.
|
||||
if constexpr(ds_rep > 0 && i_inst % ds_rep == 0)
|
||||
if constexpr(ds_rep > 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::DS_READ, 1, 0); // DS read
|
||||
}
|
||||
if constexpr(ds_rep > 0 && i_inst % ds_rep == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write
|
||||
}
|
||||
|
||||
if constexpr(buffer_load_rep > 0 && i_inst % buffer_load_rep == 0)
|
||||
{
|
||||
if constexpr(ds_write_inst > 0)
|
||||
if(i_inst % ds_rep == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read
|
||||
LLVMSchedGroupMask::DS_READ, 1, 0); // DS read
|
||||
}
|
||||
}
|
||||
if constexpr(ds_rep > 0)
|
||||
{
|
||||
if(i_inst % ds_rep == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(buffer_load_rep > 0)
|
||||
{
|
||||
if(i_inst % buffer_load_rep == 0)
|
||||
{
|
||||
if constexpr(ds_write_inst > 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read
|
||||
}
|
||||
}
|
||||
}
|
||||
// Always mark some VALU work in the loop to reflect auxiliary scalar
|
||||
@@ -354,7 +363,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
@@ -431,7 +440,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
@@ -468,7 +477,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
|
||||
Reference in New Issue
Block a user