[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603)

* chore: split block scale example instances in more separate files to speed up compile times

* wip: fp4 scaffolding for abquant

* feat: add fp4 decoding-while-loading to abquant pipeline

* feat: add support for fp4 CPU verification in abquant

* chore: add time tracking to reference calculation

* feat: add a4w4 test for blockscale gemm

* feat: optimize reference calculation by preconverting values to AccType

* feat: add fp4 to fp8 look-up table

* fix: reference to wrong ComputeDataType field in QuantProblem

* feat: type utilities for determining MFMA compute types

* feat: packed fp4 for abquant weight preshuffle

* feat: add separate tests for a4w4 base case, padding and preshuffleB

* fix: fp4 conversion on gfx950 attempting to use non-supported method

* fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size

* chore: add fp4 preshuffleb mode to block scale example

* chore: sanity check for packed types being 1 byte

* chore: clarify tensor dimension indices with constants

* chore: replace traits check with specialized check for packed types

* style: some minor refactoring and cleanup

* fix: correct conversion table for FNUZ fp8

* chore: add fp4 instances to main abquant instances again

* chore: use same initialization branch for int4 and fp4

* chore: add missing initialization for fp4 in block scale gemm example

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Erwin Terpstra
2026-01-30 12:40:50 +01:00
committed by GitHub
parent 565fea2645
commit 6a6177a246
28 changed files with 642 additions and 175 deletions

View File

@@ -101,9 +101,11 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
// 4. i4, bf8, (fp8/fp32) -> f32
static_assert(
(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t> ||
std::is_same_v<ADataType, ck_tile::pk_int4_t>) &&
std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>) &&
(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
std::is_same_v<BDataType, ck_tile::pk_int4_t>) &&
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>) &&
(std::is_same_v<AQDataType, float> || std::is_same_v<AQDataType, ck_tile::fp8_t> ||
std::is_same_v<AQDataType, ck_tile::bf8_t>) &&
(std::is_same_v<BQDataType, float> || std::is_same_v<BQDataType, ck_tile::fp8_t> ||
@@ -189,7 +191,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
typename BFlatBlockTensor,
typename AQBlockTensor,
typename BQBlockTensor,
typename ABlockWindow>
typename ABlockWindow,
index_t UnaryOpSize = 8>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
ABlockTensor& a_warp_tensor,
BFlatBlockTensor& b_warp_tensor,
@@ -249,8 +252,10 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize>(
a_warp_tensor(number<AwarpIter>{}),
a_warp_windows(number<AmIter>{})(number<AkIter>{}));
}
// barrier
// Could be deleted

View File

@@ -108,9 +108,11 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
// 4. i4, bf8, (fp8/fp32) -> f32
static_assert(
(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t> ||
std::is_same_v<ADataType, ck_tile::pk_int4_t>) &&
std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>) &&
(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
std::is_same_v<BDataType, ck_tile::pk_int4_t>) &&
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>) &&
(std::is_same_v<AQDataType, float> || std::is_same_v<AQDataType, ck_tile::fp8_t> ||
std::is_same_v<AQDataType, ck_tile::bf8_t>) &&
(std::is_same_v<BQDataType, float> || std::is_same_v<BQDataType, ck_tile::fp8_t> ||
@@ -135,12 +137,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType = std::conditional_t<
std::is_same_v<BDataType, pk_int4_t> &&
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
ADataType,
BDataType>;
// A/B DataType get converted from PkInt4/PkFp4 during loading
using OverrideADataType = ComputeDataType;
using OverrideBDataType = ComputeDataType;
using Base = BlockGemmQuantBase;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
@@ -268,9 +267,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
// If A/B datatype were pkint4/pkfp4 it would be converted prior to storing in LDS
load_int4_tile<OverrideADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
// If B datatype were pkint4 it would be converted prior to storing in LDS
load_int4_tile<OverrideBDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
}