[CK_Tile] Support for preshuffle weight(B) quant tensor for block scale gemm (#3165)

* formatted

* formatted

* formatting

* formatting

* formatting

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Split cpp file to reduce building time
- Support multiple GemmConfig

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update Readme

* enable prefill shapes

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Add support for rowcol and tensor GEMM operations

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update README

* adding preshuffle quant as new parameter and its associated new files

* remove debugging statements

* adding test

* enable preshuffle quant with permuteN

* updating readme and corresponding gemmconfigs

* updating cmake file

* fixing CI failures for grouped quant gemm

* addressing review comments

* fixing CI issue

* addressing review comments

* formatting

* formatting

* fixing aquant operator overloading

* formatting

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Khushbu Agarwal
2025-11-24 07:48:42 -08:00
committed by GitHub
parent e857e26bf6
commit 8111572785
31 changed files with 855 additions and 247 deletions

View File

@@ -463,11 +463,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Prob
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
index_t m,
index_t num_loop,
void* p_smem) const
void* p_smem,
index_t m = 0) const
{
return PipelineImpl<GemmPipelineScheduler::Interwave>{}
.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,

View File

@@ -465,9 +465,9 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
index_t m,
index_t num_loop,
void* p_smem) const
void* p_smem,
index_t m = 0) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,

View File

@@ -35,30 +35,48 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BlockGemmShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using TileEncodingPattern =
tile_distribution_encoding_pattern_bq<BlockGemmShape,
WarpGemm,
BlockSize,
KPerBlockBQ,
NPerBlockBQ,
Problem::QuantGroupSize::kN>;
if constexpr(PreshuffleQuant)
{
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<
BlockGemmShape,
WarpGemm,
BlockSize,
NPerBlock / WarpGemm::kN,
ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()),
VecLoadSize,
PreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
else
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_bq<BlockGemmShape,
WarpGemm,
BlockSize,
KPerBlockBQ,
NPerBlockBQ,
Problem::QuantGroupSize::kN>;
return TileEncodingPattern::make_2d_static_tile_distribution();
return TileEncodingPattern::make_2d_static_tile_distribution();
}
}
template <typename Problem>

View File

@@ -137,6 +137,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
@@ -238,6 +239,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t n,
index_t num_loop,
void* p_smem) const
{
@@ -257,9 +259,6 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
static_assert(KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
"Bq block window has incorrect lengths for defined BqLayout!");
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
@@ -315,8 +314,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BQDramTileWindowStep bq_dram_tile_window_step =
is_bq_col_major ? make_array(KPerBlockBQ, 0) : make_array(0, KPerBlockBQ);
const BQDramTileWindowStep bq_dram_tile_window_step =
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
0)
: is_bq_col_major ? make_array(KPerBlockBQ, 0)
: make_array(0, KPerBlockBQ);
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
@@ -457,6 +460,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
return c_block_tile;
}
};
// Overload for PreshuffleQuant = true
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename BQDramBlockWindowTmp>
@@ -464,7 +468,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
void* p_smem,
index_t n = 0) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
@@ -472,6 +477,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
bq_dram_block_window_tmp,
n,
num_loop,
p_smem);
}
@@ -502,7 +508,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* p_smem) const
void* p_smem,
index_t n = 0) const
{
const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) {
constexpr bool hot_loop = has_hot_loop_.value;
@@ -513,6 +520,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
bq_dram_block_window_tmp,
n, // dummy value, won't be used
num_loop,
p_smem);
};

View File

@@ -171,7 +171,8 @@ template <typename BlockGemmShape,
index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t XPerQ>
index_t XPerQ,
bool PreshuffleQuant = false>
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
{
static constexpr index_t warp_size = get_warp_size();
@@ -213,52 +214,71 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
/// @return A static tile distribution encoding for the BQ scale tensor
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
if constexpr(XPerQ < WarpGemm::kN)
if constexpr(PreshuffleQuant)
{
// Case 1: Fine-grained - multiple quantization scales within a single warp
constexpr index_t Y = YPerTile; // Full Y dimension of tile
constexpr index_t YR = 1; // No Y replication needed
constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim
constexpr index_t X1 = NWarps; // Number of warps in N-dim
constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp
constexpr index_t XR = XPerQ; // Elements per quantization group
static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along X.");
constexpr index_t X1 = warp_size;
constexpr index_t X0 = XPerTile / warp_size;
constexpr index_t Y1 = NWarps;
constexpr index_t Y0 = YPerTile / Y1;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, YR, XR>,
tuple<sequence<Y>, sequence<X0, X1, X2>>,
tuple<sequence<0, 2>, sequence<0, 2, 0>>,
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
else if constexpr(XPerQ <= WarpGemm::kN * NWarps)
{
// Case 2: Medium-grained - one quantization scale per warp
constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
constexpr auto X1 = NWarps / XR; // Warps per unique scale
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, XR, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<X0, X1>>,
tuple<sequence<0, 2, 0>, sequence<0>>,
tuple<sequence<0, 1, 1>, sequence<2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
else // XPerQ > WarpGemm::kN * NWarps
{
// Case 3: Coarse-grained - quantization group spans all warps
// All warps in N-dimension share the same quantization scale
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<XPerTile>>,
tuple<sequence<0, 0>, sequence<0>>,
tile_distribution_encoding<sequence<MWarps>,
tuple<sequence<Y0, Y1>, sequence<X0, X1>>,
tuple<sequence<0, 1>, sequence<2>>,
sequence<2, 1>,
tuple<sequence<0, 1>, sequence<1>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
else
{
if constexpr(XPerQ < WarpGemm::kN)
{
// Case 1: Fine-grained - multiple quantization scales within a single warp
constexpr index_t Y = YPerTile; // Full Y dimension of tile
constexpr index_t YR = 1; // No Y replication needed
constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim
constexpr index_t X1 = NWarps; // Number of warps in N-dim
constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp
constexpr index_t XR = XPerQ; // Elements per quantization group
static_assert(X0 * X1 * X2 == XPerTile,
"X0, X1, X2 must cover the blocktile along X.");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, YR, XR>,
tuple<sequence<Y>, sequence<X0, X1, X2>>,
tuple<sequence<0, 2>, sequence<0, 2, 0>>,
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
else if constexpr(XPerQ <= WarpGemm::kN * NWarps)
{
// Case 2: Medium-grained - one quantization scale per warp
constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
constexpr auto X1 = NWarps / XR; // Warps per unique scale
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, XR, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<X0, X1>>,
tuple<sequence<0, 2, 0>, sequence<0>>,
tuple<sequence<0, 1, 1>, sequence<2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
else // XPerQ > WarpGemm::kN * NWarps
{
// Case 3: Coarse-grained - quantization group spans all warps
// All warps in N-dimension share the same quantization scale
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<XPerTile>>,
tuple<sequence<0, 0>, sequence<0>>,
tuple<sequence<0, 1>, sequence<2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
}
}
};

View File

@@ -68,6 +68,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
using Base::m_preload;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr index_t KPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
static constexpr index_t QScalesPerBlockRow =
@@ -106,6 +107,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t n,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
@@ -236,7 +238,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
// BQ DRAM window for load
auto bq_copy_dram_window =
make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<KPerBlockBQ>{}, number<kNPerBlock>{}),
bq_dram_block_window_tmp.get_window_lengths(),
bq_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeBQDramTileDistribution<Problem>());
@@ -269,8 +271,17 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
BQBlockTile bq_block_tile, bq_block_tile_2;
bq_block_tile = load_tile(bq_copy_dram_window);
// move BQ to tile 1
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
if constexpr(PreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
0});
}
else
{
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
}
// Prefill A0
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
@@ -318,7 +329,17 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
bq_block_tile_2 = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
if constexpr(PreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
0});
}
else
{
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
}
// Prefill A(2i+1)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
@@ -360,7 +381,17 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
if constexpr(PreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
0});
}
else
{
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
}
// Prefill A(2i+2)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
@@ -448,6 +479,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
return c_block_tile;
}
// Replace lines 485-526 with a single optimized operator:
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename BQDramBlockWindowTmp>
@@ -456,14 +488,15 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
void* p_smem_pong,
index_t n = 0) const // Default value for non-preshuffle case
{
return operator()<TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_flat_dram_block_window_tmp,
bq_dram_block_window_tmp,
n,
num_loop,
p_smem_ping,
p_smem_pong);
@@ -478,7 +511,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
index_t num_loop,
TailNumber tail_number,
void* p_smem_ping,
void* p_smem_pong) const
void* p_smem_pong,
index_t n = 0) const
{
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
(void)bool_val; // Suppress unused parameter warning
@@ -488,6 +522,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
[](const ADataType& a) { return a; },
b_flat_dram_block_window_tmp,
bq_dram_block_window_tmp,
n, // dummy value, won't be used
num_loop,
p_smem_ping,
p_smem_pong);