[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603)

* chore: split block scale example instances in more separate files to speed up compile times

* wip: fp4 scaffolding for abquant

* feat: add fp4 decoding-while-loading to abquant pipeline

* feat: add support for fp4 CPU verification in abquant

* chore: add time tracking to reference calculation

* feat: add a4w4 test for blockscale gemm

* feat: optimize reference calculation by preconverting values to AccType

* feat: add fp4 to fp8 look-up table

* fix: reference to wrong ComputeDataType field in QuantProblem

* feat: type utilities for determining MFMA compute types

* feat: packed fp4 for abquant weight preshuffle

* feat: add separate tests for a4w4 base case, padding and preshuffleB

* fix: fp4 conversion on gfx950 attempting to use non-supported method

* fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size

* chore: add fp4 preshuffleb mode to block scale example

* chore: sanity check for packed types being 1 byte

* chore: clarify tensor dimension indices with constants

* chore: replace traits check with specialized check for packed types

* style: some minor refactoring and cleanup

* fix: correct conversion table for FNUZ fp8

* chore: add fp4 instances to main abquant instances again

* chore: use same initialization branch for int4 and fp4

* chore: add missing initialization for fp4 in block scale gemm example

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Erwin Terpstra
2026-01-30 12:40:50 +01:00
committed by GitHub
parent 565fea2645
commit 6a6177a246
28 changed files with 642 additions and 175 deletions

View File

@@ -101,9 +101,11 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
// 4. i4, bf8, (fp8/fp32) -> f32
static_assert(
(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t> ||
std::is_same_v<ADataType, ck_tile::pk_int4_t>) &&
std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>) &&
(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
std::is_same_v<BDataType, ck_tile::pk_int4_t>) &&
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>) &&
(std::is_same_v<AQDataType, float> || std::is_same_v<AQDataType, ck_tile::fp8_t> ||
std::is_same_v<AQDataType, ck_tile::bf8_t>) &&
(std::is_same_v<BQDataType, float> || std::is_same_v<BQDataType, ck_tile::fp8_t> ||
@@ -189,7 +191,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
typename BFlatBlockTensor,
typename AQBlockTensor,
typename BQBlockTensor,
typename ABlockWindow>
typename ABlockWindow,
index_t UnaryOpSize = 8>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
ABlockTensor& a_warp_tensor,
BFlatBlockTensor& b_warp_tensor,
@@ -249,8 +252,10 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize>(
a_warp_tensor(number<AwarpIter>{}),
a_warp_windows(number<AmIter>{})(number<AkIter>{}));
}
// barrier
// Could be deleted

View File

@@ -108,9 +108,11 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
// 4. i4, bf8, (fp8/fp32) -> f32
static_assert(
(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t> ||
std::is_same_v<ADataType, ck_tile::pk_int4_t>) &&
std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>) &&
(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
std::is_same_v<BDataType, ck_tile::pk_int4_t>) &&
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>) &&
(std::is_same_v<AQDataType, float> || std::is_same_v<AQDataType, ck_tile::fp8_t> ||
std::is_same_v<AQDataType, ck_tile::bf8_t>) &&
(std::is_same_v<BQDataType, float> || std::is_same_v<BQDataType, ck_tile::fp8_t> ||
@@ -135,12 +137,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType = std::conditional_t<
std::is_same_v<BDataType, pk_int4_t> &&
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
ADataType,
BDataType>;
// A/B DataType get converted from PkInt4/PkFp4 during loading
using OverrideADataType = ComputeDataType;
using OverrideBDataType = ComputeDataType;
using Base = BlockGemmQuantBase;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
@@ -268,9 +267,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
// If A/B datatype were pkint4/pkfp4 it would be converted prior to storing in LDS
load_int4_tile<OverrideADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
// If B datatype were pkint4 it would be converted prior to storing in LDS
load_int4_tile<OverrideBDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
}

View File

@@ -10,9 +10,10 @@
namespace ck_tile {
struct GemmABQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy
struct GemmABQuantPipelineAgBgCrDefaultPolicy
: public UniversalGemmBasePolicy<GemmABQuantPipelineAgBgCrDefaultPolicy>
{
using Base = UniversalGemmPipelineAgBgCrPolicy;
using Base = UniversalGemmBasePolicy<GemmABQuantPipelineAgBgCrDefaultPolicy>;
using Base::I0;
using Base::I1;
using Base::I2;

View File

@@ -34,9 +34,6 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
@@ -67,6 +64,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
// A/B DataType gets converted from PkInt4/PkFp4 during loading
using OverrideADataType = BlockGemm::OverrideADataType;
using OverrideBDataType = BlockGemm::OverrideBDataType;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
@@ -281,9 +282,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex;
using BQDramTileWindowStep = typename BQDramBlockWindowTmp::BottomTensorIndex;
// Note: BDataType PkInt4 gets converted during loading, before going to LDS
// Note: A/B DataType PkInt4/PkFp4 gets converted during loading, before going to LDS
auto&& [a_lds_block, b_lds_block] =
Base::template GetABLdsTensorViews<ADataType, OverrideBDataType>(p_smem);
Base::template GetABLdsTensorViews<OverrideADataType, OverrideBDataType>(p_smem);
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
@@ -303,9 +304,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
using BQBlockTileDistr = decltype(bq_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
decltype(make_static_distributed_tensor<OverrideADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
decltype(make_static_distributed_tensor<OverrideBDataType>(BBlockTileDistr{}));
using AQBlockTile =
decltype(make_static_distributed_tensor<AQDataType>(AQBlockTileDistr{}));
using BQBlockTile =
@@ -361,7 +362,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
auto a_shuffle_tmp = make_static_distributed_tensor<OverrideADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
@@ -373,7 +374,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -409,7 +410,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
// Note: ABDataType PkInt4/PkFp4 gets converted during loading earlier
auto a_shuffle_tmp = make_static_distributed_tensor<OverrideADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
@@ -420,7 +422,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
}
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
// Note: BDataType PkInt4 gets converted during loading earlier
// Note: BDataType PkInt4/PkFp4 gets converted during loading earlier
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
@@ -493,7 +495,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
// Note: ADataType gets converted during loading from PkInt4/PkFp4
auto a_shuffle_tmp = make_static_distributed_tensor<OverrideADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
@@ -543,9 +546,9 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
[](const OverrideADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
[](const OverrideBDataType& b) { return b; },
aq_dram_block_window_tmp,
bq_dram_block_window_tmp,
m,
@@ -593,9 +596,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
// Note: ADataType PkInt4/PkFp4 gets converted during loading
[](const OverrideADataType& a) { return a; },
b_dram_block_window_tmp,
// Note: BDataType PkInt4 gets converted during loading
// Note: BDataType PkInt4/PkFp4 gets converted during loading
[](const OverrideBDataType& b) { return b; },
aq_dram_block_window_tmp,
bq_dram_block_window_tmp,

View File

@@ -21,23 +21,27 @@ template <typename ADataType_,
typename AQuantGroupSize_,
typename BQuantGroupSize_,
bool TransposeC_,
typename ComputeDataType_ = BDataType_,
typename ComputeDataType_ = void,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>
struct GemmQuantPipelineProblemBase
: public GemmPipelineProblemBase<
ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>
{
using Base = GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>;
using Base = GemmPipelineProblemBase<
ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
mixed_prec_compute_type_t<ComputeDataType_, ADataType_, BDataType_>>;
using Traits = typename Base::Traits;

View File

@@ -95,11 +95,6 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using BTypeToUse =
std::conditional_t<std::is_same_v<typename Problem::BDataType, ck_tile::pk_int4_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
using BDataType = typename Problem::BDataType;
@@ -107,8 +102,8 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
KLane / numeric_traits<BDataType>::PackedSize * sizeof(BDataType);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
BTypeToUse,
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),

View File

@@ -7,6 +7,7 @@
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp"
@@ -239,36 +240,42 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
// A DRAM tile window for load
auto a_dram_tile_distribution =
PipelinePolicy::template MakeADramTileDistribution<Problem>();
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeADramTileDistribution<Problem>());
a_dram_tile_distribution);
auto a_copy_lds_window_ping =
make_tile_window(a_lds_block_ping,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
a_dram_tile_distribution);
auto a_copy_lds_window_pong =
make_tile_window(a_lds_block_pong,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
a_dram_tile_distribution);
// ping-pong window for A LDS
auto a_warp_tile_distribution =
make_static_tile_distribution(typename WG::AWarpDstrEncoding{});
auto a_warp_window_ping_tmp =
make_tile_window(a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
a_warp_tile_distribution);
auto a_warp_window_pong_tmp =
make_tile_window(a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
a_warp_tile_distribution);
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
@@ -314,7 +321,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
b_flat_distribution);
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
mixed_prec_compute_type_from_input_t<BDataType, ADataType, ComputeDataType>;
using BTileType = decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
// pingpong buffer for B
@@ -354,7 +361,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -393,15 +400,17 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
block_sync_lds();
// preload A00,A10 from lds
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
m_preload>
a_warp_tensor;
using ATypeToUse =
mixed_prec_compute_type_from_input_t<ADataType, BDataType, ComputeDataType>;
using ATileType =
decltype(make_static_distributed_tensor<BTypeToUse>(a_warp_tile_distribution));
statically_indexed_array<ATileType, m_preload> a_warp_tensor;
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
a_warp_tensor(loadIter), a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0);
@@ -434,7 +443,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -450,8 +459,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
a_warp_tensor(loadIter), a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
// Next K
@@ -463,7 +472,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -495,8 +504,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
a_warp_tensor(loadIter), a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
iCounter--;
HotLoopScheduler<loop_count>();
@@ -513,7 +522,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -535,8 +544,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
a_warp_tensor(loadIter), a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
// GEMM loopK