[CK TILE QUANT GEMM] use OverrideADataType in aquant pipeline (#3584)

This commit is contained in:
Cong Ma
2026-01-16 16:27:39 -07:00
committed by GitHub
parent 3f735c127b
commit f9104ef9b3
2 changed files with 35 additions and 32 deletions

View File

@@ -28,7 +28,11 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
// When ADataType is pk_int4_t, use BDataType instead for transpose operations
// since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision)
using OverrideADataType =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
@@ -228,9 +232,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"B block window has incorrect lengths for defined BLayout!");
// A/B tiles in LDS - using the same approach as regular gemm pipeline
auto ab_lds_blocks = Base::template GetABLdsTensorViews<BDataType, BDataType>(p_smem);
auto& a_lds_block = ab_lds_blocks.at(I0{});
auto& b_lds_block = ab_lds_blocks.at(I1{});
auto ab_lds_blocks =
Base::template GetABLdsTensorViews<OverrideADataType, BDataType>(p_smem);
auto& a_lds_block = ab_lds_blocks.at(I0{});
auto& b_lds_block = ab_lds_blocks.at(I1{});
// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr =
@@ -260,7 +265,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<BDataType>(ABlockTileDistr{}));
decltype(make_static_distributed_tensor<OverrideADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
using AQBlockTile =
@@ -295,7 +300,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// LDS prefill - VGPRs to LDS
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<BDataType>(
auto a_shuffle_tmp = make_static_distributed_tensor<OverrideADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
@@ -346,7 +351,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// Prepare next iteration data
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<BDataType>(
auto a_shuffle_tmp = make_static_distributed_tensor<OverrideADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(
a_shuffle_tmp,
@@ -406,7 +411,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<BDataType>(
auto a_shuffle_tmp = make_static_distributed_tensor<OverrideADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp,
a_block_tiles.get(number<prefetch_idx + 1>{}));
@@ -494,7 +499,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
return PipelineImpl<GemmPipelineScheduler::Intrawave>{}
.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const BDataType& a) { return a; },
[](const OverrideADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
aq_dram_block_window_tmp,

View File

@@ -25,7 +25,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
// When ADataType is pk_int4_t, use BDataType instead for transpose operations
// since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision)
using OverrideADataType =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
@@ -164,14 +168,17 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
{
using Base = PipelineImplBase;
template <typename ADramWindow, typename ABlockTile_>
CK_TILE_DEVICE static void LoadAndConvertATile(ABlockTile_& a_block_tile,
const ADramWindow& a_dram_window)
template <typename ADramWindow, typename ABlockTile_, typename DramTileWindowStep>
CK_TILE_DEVICE static void
LoadAndConvertATile(ABlockTile_& a_block_tile,
ADramWindow& a_dram_window,
const DramTileWindowStep& dram_tile_window_step)
{
using DestDataType = typename ABlockTile_::DataType;
using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
constexpr index_t UnaryOpSize = 8;
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(a_block_tile, a_dram_window);
move_tile_window(a_dram_window, dram_tile_window_step);
}
template <bool HasHotLoop,
@@ -224,7 +231,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex;
auto&& [a_lds_block, b_lds_block] =
Base::template GetABLdsTensorViews<BDataType, BDataType>(p_smem);
Base::template GetABLdsTensorViews<OverrideADataType, BDataType>(p_smem);
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
@@ -241,11 +248,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution());
// while ADatatype might not be the same as BDataType at the time of problem
// initialization, we can safely use BDataType here because when A would be int4 we will
// ensure A is converted to BDataType prior to loading
using ABlockTile =
decltype(make_static_distributed_tensor<BDataType>(ABlockTileDistr{}));
decltype(make_static_distributed_tensor<OverrideADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
using AQBlockTile =
@@ -274,8 +278,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
: (is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ));
// DRAM prefetch (global read 0)
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
LoadAndConvertATile(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(
aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step);
@@ -284,7 +287,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<BDataType>(
auto a_shuffle_tmp = make_static_distributed_tensor<OverrideADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
@@ -306,8 +309,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
LoadAndConvertATile(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
@@ -328,7 +330,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<BDataType>(
auto a_shuffle_tmp = make_static_distributed_tensor<OverrideADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
@@ -349,8 +351,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
LoadAndConvertATile(a_block_tile, a_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
LoadAndConvertATile(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2],
aq_copy_dram_window,
@@ -389,7 +390,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<BDataType>(
auto a_shuffle_tmp = make_static_distributed_tensor<OverrideADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
@@ -430,10 +431,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
// Note: a_element_func takes BDataType (not ADataType) because A tiles are
// converted from ADataType (e.g., pk_int4_t) to BDataType (e.g., fp8) in
// LoadAndConvertATile before the element function is applied.
[](const BDataType& a) { return a; },
[](const OverrideADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
aq_dram_block_window_tmp,
@@ -476,7 +474,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
constexpr auto tail_num = tail_number_.value;
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
[](const OverrideADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
aq_dram_block_window_tmp,