[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603)

* chore: split block scale example instances in more separate files to speed up compile times

* wip: fp4 scaffolding for abquant

* feat: add fp4 decoding-while-loading to abquant pipeline

* feat: add support for fp4 CPU verification in abquant

* chore: add time tracking to reference calculation

* feat: add a4w4 test for blockscale gemm

* feat: optimize reference calculation by preconverting values to AccType

* feat: add fp4 to fp8 look-up table

* fix: reference to wrong ComputeDataType field in QuantProblem

* feat: type utilities for determining MFMA compute types

* feat: packed fp4 for abquant weight preshuffle

* feat: add separate tests for a4w4 base case, padding and preshuffleB

* fix: fp4 conversion on gfx950 attempting to use non-supported method

* fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size

* chore: add fp4 preshuffleb mode to block scale example

* chore: sanity check for packed types being 1 byte

* chore: clarify tensor dimension indices with constants

* chore: replace traits check with specialized check for packed types

* style: some minor refactoring and cleanup

* fix: correct conversion table for FNUZ fp8

* chore: add fp4 instances to main abquant instances again

* chore: use same initialization branch for int4 and fp4

* chore: add missing initialization for fp4 in block scale gemm example

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Erwin Terpstra
2026-01-30 12:40:50 +01:00
committed by GitHub
parent 565fea2645
commit 6a6177a246
28 changed files with 642 additions and 175 deletions

View File

@@ -397,6 +397,29 @@ struct PassThroughPack8
y.hi = i4_to_bf8x4(bit_cast<int>(x) >> 8);
#endif
}
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_fp4x4_t& x) const
{
pk_fp4_t f0 = pk_fp4_t{x[0]};
pk_fp4_t f1 = pk_fp4_t{x[1]};
pk_fp4_t f2 = pk_fp4_t{x[2]};
pk_fp4_t f3 = pk_fp4_t{x[3]};
fp8x2_t x0 = f0.to_fp8x2();
fp8x2_t x1 = f1.to_fp8x2();
fp8x2_t x2 = f2.to_fp8x2();
fp8x2_t x3 = f3.to_fp8x2();
y[0] = x0[0];
y[1] = x0[1];
y[2] = x1[0];
y[3] = x1[1];
y[4] = x2[0];
y[5] = x2[1];
y[6] = x3[0];
y[7] = x3[1];
}
constexpr const static bool is_pack8_invocable = true;
};