[CK_TILE] Support more layouts for BQuant GEMM (#3349)

* WIP: preparing to add transpose bq support

* WIP: handle both row/col layout for BQ windows/tile dstr

* Fix build

* WIP: adding some test, debugging numerical errors

* Fix all but pkint4 tests

* Remove test_gemm_quant_typed.cpp again

* update disabled tests

* add conversion from pkint4 for b matrix

* fix formatting

* fix formatting

* Fix tr_load and use override b datatype for clarity

* fix formatting

* Make BQuant preshuffle tests use a column-major BQ layout
This commit is contained in:
Sami Remes
2025-12-08 21:05:56 +00:00
committed by GitHub
parent fe07b5a1bf
commit c363a98d41
10 changed files with 359 additions and 189 deletions

View File

@@ -61,6 +61,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
@@ -154,6 +155,10 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using Base = BlockGemmBQuantBase<Problem_>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
@@ -271,12 +276,20 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
// If B datatype were pkint4 it would be converted prior to storing in LDS
load_int4_tile<OverrideBDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
}
// C += A * B
@@ -397,11 +410,16 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
MakeCBlockTile();
}
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
}
// C += A * B

View File

@@ -426,7 +426,6 @@ struct QuantGemmKernel
if constexpr(kQuantType == QuantType::BQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
@@ -781,7 +780,9 @@ struct QuantGemmKernel
{
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
"PreshuffleQuant with BQuantGrouped currently only supports "
"ColumnMajor BQ layout");
return MakePreshuffledQuantTensorView<
GemmPipeline::KPerBlockBQ,
@@ -791,14 +792,35 @@ struct QuantGemmKernel
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
make_tuple(kargs.stride_BQ, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
// For RowMajor BQ: memory layout is [K/QuantGroupK][N/QuantGroupN]
// Dimensions: [K/QuantGroupK, N/QuantGroupN]
// Strides: [N/QuantGroupN, 1]
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
// For ColumnMajor BQ: memory layout is [N/QuantGroupN][K/QuantGroupK]
// Dimensions: [N/QuantGroupN, K/QuantGroupK]
// Strides: [K/QuantGroupK, 1]
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
}
}
else
@@ -1023,10 +1045,10 @@ struct QuantGemmKernel
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
@@ -1042,13 +1064,23 @@ struct QuantGemmKernel
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
{0, i_n / QuantGroupSize::kN});
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
}
}
}
else

View File

@@ -42,14 +42,18 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
CK_TILE_DEVICE constexpr auto
GetBQDramLoadWindow(const BQDramBlockWindowTmp& bq_dram_block_window_tmp) const
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using YPerTile = number<NPerBlockBQ>;
using XPerTile = number<KPerBlockBQ>;
using YPerTile =
std::conditional_t<std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
number<NPerBlockBQ>,
number<KPerBlockBQ>>;
using XPerTile =
std::conditional_t<std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
number<KPerBlockBQ>,
number<NPerBlockBQ>>;
auto bq_copy_dram_window =
make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(YPerTile(), XPerTile()),
make_tuple(YPerTile{}, XPerTile{}),
bq_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBQDramTileDistribution<Problem>());
return bq_copy_dram_window;

View File

@@ -25,8 +25,16 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
// Support both RowMajor and ColumnMajor layouts for BQ
if constexpr(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetABQGlobalVectorLoadSize<Problem, BQDataType, KPerBlockBQ, NPerBlockBQ>();
}
else
{
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
}
}
template <typename Problem>
@@ -52,7 +60,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
WarpTile::at(I2),
Problem::TransposeC>;
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
if constexpr(PreshuffleQuant)
{
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<
@@ -62,18 +69,21 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
NPerBlock / WarpGemm::kN,
ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()),
VecLoadSize,
BQLayout,
PreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
else
{
// KPerTile and NPerTile are LOGICAL dimensions (K quant groups and N quant groups)
using TileEncodingPattern =
tile_distribution_encoding_pattern_bq<BlockGemmShape,
WarpGemm,
BlockSize,
NPerBlockBQ,
KPerBlockBQ,
Problem::QuantGroupSize::kN>;
KPerBlockBQ, // Logical K dimension
NPerBlockBQ, // Logical N dimension
Problem::QuantGroupSize::kN,
BQLayout>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}

View File

@@ -33,6 +33,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
using I0 = number<0>;
using I1 = number<1>;
@@ -83,6 +87,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
using Base::PrefetchStages;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
@@ -125,7 +132,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t BQ_Buffer_Load_Inst_Num =
NPerBlock * KPerBlockBQ / (BlockSize * GetVectorSizeBQ());
NPerBlockBQ * KPerBlockBQ / (BlockSize * GetVectorSizeBQ());
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
@@ -167,6 +174,16 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
{
using Base = PipelineImplBase;
/// @brief Loads a B-matrix tile from a DRAM window into the register block tile,
/// converting the element datatype during the load.
///
/// Per the surrounding pipeline comments, this is how a pk_int4 BDataType gets
/// converted (to the compute/A datatype) before the tile is written to LDS.
///
/// @param b_block_tile  Destination distributed block tile (its DataType is the
///                      conversion target).
/// @param b_dram_window Source DRAM tile window (its underlying DataType is the
///                      conversion source).
template <typename BDramWindow, typename BBlockTile_>
CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile,
const BDramWindow& b_dram_window)
{
// Destination type comes from the register tile; source type from the
// DRAM window's underlying tile-window base.
using DestDataType = typename BBlockTile_::DataType;
using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType;
// Vector width for the per-element conversion op applied during the load
// (matches the UnaryOpSize_ used elsewhere in this pipeline — TODO confirm
// 8 is required rather than merely chosen here).
constexpr index_t UnaryOpSize = 8;
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
@@ -194,11 +211,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_bq_col_major =
std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
constexpr bool is_bq_row_major =
std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
@@ -212,12 +227,22 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
static_assert(
PreshuffleQuant ||
(is_bq_row_major
? (KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])),
"Bq block window has incorrect lengths for defined BqLayout!");
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
using BQDramTileWindowStep = typename BQDramBlockWindowTmp::BottomTensorIndex;
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
// Note: BDataType PkInt4 gets converted during loading, before going to LDS
auto&& [a_lds_block, b_lds_block] =
Base::template GetABLdsTensorViews<ADataType, OverrideBDataType>(p_smem);
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
@@ -237,7 +262,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
decltype(make_static_distributed_tensor<ADataType>(BBlockTileDistr{}));
using BQBlockTile =
decltype(make_static_distributed_tensor<BQDataType>(BQBlockTileDistr{}));
@@ -258,18 +283,20 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
0)
: is_bq_col_major ? make_array(0, KPerBlockBQ)
: make_array(KPerBlockBQ, 0);
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
: make_array(0, KPerBlockBQ);
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// B tile gets converted to A datatype during loading
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(
bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -281,9 +308,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
// B datatype is converted to A datatype during loading
auto b_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -294,11 +322,13 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
__builtin_amdgcn_sched_barrier(0);
@@ -311,7 +341,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
{
block_sync_lds();
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -322,9 +352,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
// Note: BDataType PkInt4 gets converted during loading earlier
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -335,7 +366,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2],
bq_copy_dram_window,
bq_dram_tile_window_step);
@@ -347,7 +379,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
__builtin_amdgcn_sched_barrier(0);
i += 1;
@@ -383,7 +416,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
// Note: BDataType gets converted during loading from PkInt4
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -393,7 +427,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm(
c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
}
@@ -415,7 +450,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
// Note: BDataType PkInt4 gets converted during loading
[](const OverrideBDataType& b) { return b; },
bq_dram_block_window_tmp,
n,
num_loop,
@@ -458,7 +494,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
// Note: BDataType PkInt4 gets converted during loading
[](const OverrideBDataType& b) { return b; },
bq_dram_block_window_tmp,
n, // dummy value, won't be used
num_loop,

View File

@@ -189,9 +189,10 @@ struct tile_distribution_encoding_pattern_aq_transposed_c
template <typename BlockGemmShape,
typename WarpGemm,
index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t YPerQ,
index_t KPerTile,
index_t NPerTile,
index_t NPerQ,
typename BQLayout = tensor_layout::gemm::ColumnMajor,
bool PreshuffleQuant = false>
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
{
@@ -210,36 +211,41 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
///
/// This function determines the optimal thread distribution pattern for loading and applying
/// quantization scales to the B matrix based on the quantization group size (XPerQ) relative
/// quantization scales to the B matrix based on the quantization group size (NPerQ) relative
/// to warp dimensions.
///
/// Three distinct distribution patterns are handled:
///
/// 1. Fine-grained quantization (XPerQ < WarpGemm::kN):
/// 1. Fine-grained quantization (NPerQ < WarpGemm::kN):
/// - Multiple quantization groups exist within a single warp's N-dimension
/// - Each warp processes multiple scales (WarpGemm::kN / XPerQ scales per warp)
/// - Distribution includes explicit replication factor (XR = XPerQ) for scale broadcast
/// - Example: XPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
/// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp)
/// - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast
/// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
///
/// 2. Medium-grained quantization (WarpGemm::kN <= XPerQ <= WarpGemm::kN * NWarps):
/// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps):
/// - Each warp handles exactly one quantization scale
/// - Scales are distributed across warps with replication factor XR = XPerQ / WarpGemm::kN
/// - Example: XPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
/// - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN
/// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
///
/// 3. Coarse-grained quantization (XPerQ > WarpGemm::kN * NWarps):
/// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps):
/// - Quantization group spans multiple warps
/// - All warps share the same scale value
/// - Example: XPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
/// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
///
/// @return A static tile distribution encoding for the BQ scale tensor
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
// Preshuffle only supported for ColumnMajor currently
static_assert(!(PreshuffleQuant && std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>),
"PreshuffleQuant only supported for ColumnMajor BQLayout");
if constexpr(PreshuffleQuant)
{
// ColumnMajor only for preshuffle
constexpr index_t X1 = warp_size;
constexpr index_t X0 = XPerTile / warp_size;
constexpr index_t X0 = NPerTile / warp_size;
constexpr index_t Y1 = NWarps;
constexpr index_t Y0 = YPerTile / Y1;
constexpr index_t Y0 = KPerTile / Y1;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps>,
@@ -251,52 +257,97 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
}
else
{
if constexpr(YPerQ < WarpGemm::kN)
if constexpr(NPerQ < WarpGemm::kN)
{
// Case 1: Fine-grained - multiple quantization scales within a single warp
constexpr index_t X = XPerTile; // Full X dimension of tile
constexpr index_t XR = 1; // No Y replication needed
constexpr index_t Y0 = NIterPerWarp; // Iterations per warp in N-dim
constexpr index_t Y1 = NWarps; // Number of warps in N-dim
constexpr index_t Y2 = WarpGemm::kN / YPerQ; // Number of scales per warp
constexpr index_t YR = YPerQ; // Elements per quantization group
// N dimension needs to be partitioned the same way regardless of layout
constexpr index_t NR = 1; // No N replication needed
constexpr index_t N0 = NIterPerWarp; // Iterations per warp in N-dim
constexpr index_t N1 = NWarps; // Number of warps in N-dim
constexpr index_t N2 = WarpGemm::kN / NPerQ; // Number of scales per warp
static_assert(Y0 * Y1 * Y2 == YPerTile,
"Y0, Y1, Y2 must cover the blocktile along Y.");
static_assert(N0 * N1 * N2 == NPerTile,
"N0, N1, N2 must cover the blocktile along N dimension.");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, XR, YR>,
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
tuple<sequence<0, 1>, sequence<0, 1, 0>>,
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
{
// ColumnMajor: [(N0, N1, N2), K] - N on Y-axis, partition Y
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR, NPerQ>,
tuple<sequence<N0, N1, N2>, sequence<KPerTile>>,
tuple<sequence<0, 1>, sequence<0, 1, 0>>,
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
else
{
// RowMajor: [K, (N0, N1, N2)] - N on X-axis, partition X
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR, NPerQ>,
tuple<sequence<KPerTile>, sequence<N0, N1, N2>>,
tuple<sequence<0, 2>, sequence<0, 2, 0>>,
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
}
else if constexpr(YPerQ <= WarpGemm::kN * NWarps)
else if constexpr(NPerQ <= WarpGemm::kN * NWarps)
{
// Case 2: Medium-grained - one quantization scale per warp
constexpr auto YR = YPerQ / WarpGemm::kN; // Scale replication factor
constexpr auto Y1 = NWarps / YR; // Warps per unique scale
constexpr auto Y0 = YPerTile / Y1; // Iterations to cover X dimension
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, YR, get_warp_size()>,
tuple<sequence<Y0, Y1>, sequence<XPerTile>>,
tuple<sequence<0, 1, 0>, sequence<0>>,
tuple<sequence<0, 1, 1>, sequence<2>>,
sequence<1, 2>,
sequence<0, 0>>{});
constexpr auto NR = NPerQ / WarpGemm::kN; // Scale replication factor
constexpr auto N1 = NWarps / NR; // Warps per unique scale
constexpr auto N0 = NPerTile / N1; // Iterations to cover N dimension
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
{
// ColumnMajor: [(N0, N1), K] - N on Y-axis
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR, get_warp_size()>,
tuple<sequence<N0, N1>, sequence<KPerTile>>,
tuple<sequence<0, 1, 0>, sequence<0>>,
tuple<sequence<0, 1, 1>, sequence<2>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
else
{
// RowMajor: [K, (N0, N1)] - N on X-axis
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR, get_warp_size()>,
tuple<sequence<KPerTile>, sequence<N0, N1>>,
tuple<sequence<0, 2, 0>, sequence<0>>,
tuple<sequence<0, 1, 1>, sequence<2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
}
else // XPerQ > WarpGemm::kN * NWarps
else // NPerQ > WarpGemm::kN * NWarps
{
// Case 3: Coarse-grained - quantization group spans all warps
// All warps in N-dimension share the same quantization scale
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<XPerTile>>,
tuple<sequence<0, 0>, sequence<0>>,
tuple<sequence<0, 1>, sequence<2>>,
sequence<2, 1>,
sequence<0, 0>>{});
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
{
// ColumnMajor: [N, K]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, get_warp_size()>,
tuple<sequence<NPerTile>, sequence<KPerTile>>,
tuple<sequence<0, 0>, sequence<0>>,
tuple<sequence<0, 1>, sequence<2>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
else
{
// RowMajor: [K, N]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, get_warp_size()>,
tuple<sequence<KPerTile>, sequence<NPerTile>>,
tuple<sequence<0, 0>, sequence<0>>,
tuple<sequence<0, 1>, sequence<2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
}
}
}