mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603)
* chore: split block scale example instances in more separate files to speed up compile times
* wip: fp4 scaffolding for abquant
* feat: add fp4 decoding-while-loading to abquant pipeline
* feat: add support for fp4 CPU verification in abquant
* chore: add time tracking to reference calculation
* feat: add a4w4 test for blockscale gemm
* feat: optimize reference calculation by preconverting values to AccType
* feat: add fp4 to fp8 look-up table
* fix: reference to wrong ComputeDataType field in QuantProblem
* feat: type utilities for determining MFMA compute types
* feat: packed fp4 for abquant weight preshuffle
* feat: add separate tests for a4w4 base case, padding and preshuffleB
* fix: fp4 conversion on gfx950 attempting to use non-supported method
* fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size
* chore: add fp4 preshuffleb mode to block scale example
* chore: sanity check for packed types being 1 byte
* chore: clarify tensor dimension indices with constants
* chore: replace traits check with specialized check for packed types
* style: some minor refactoring and cleanup
* fix: correct conversion table for FNUZ fp8
* chore: add fp4 instances to main abquant instances again
* chore: use same initialization branch for int4 and fp4
* chore: add missing initialization for fp4 in block scale gemm example

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -397,6 +397,29 @@ struct PassThroughPack8
|
||||
y.hi = i4_to_bf8x4(bit_cast<int>(x) >> 8);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Expands the four elements of `x` into the eight fp8 lanes of `y`.
//
// Each element x[k] (k = 0..3) is wrapped in a pk_fp4_t and converted with
// pk_fp4_t::to_fp8x2(). The first value of the resulting pair is stored in
// y[2k] and the second in y[2k + 1], so the output keeps the input element
// order.
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_fp4x4_t& x) const
{
    // Element 0 -> output lanes 0 and 1.
    pk_fp4_t packed0{x[0]};
    fp8x2_t pair0 = packed0.to_fp8x2();
    y[0]          = pair0[0];
    y[1]          = pair0[1];

    // Element 1 -> output lanes 2 and 3.
    pk_fp4_t packed1{x[1]};
    fp8x2_t pair1 = packed1.to_fp8x2();
    y[2]          = pair1[0];
    y[3]          = pair1[1];

    // Element 2 -> output lanes 4 and 5.
    pk_fp4_t packed2{x[2]};
    fp8x2_t pair2 = packed2.to_fp8x2();
    y[4]          = pair2[0];
    y[5]          = pair2[1];

    // Element 3 -> output lanes 6 and 7.
    pk_fp4_t packed3{x[3]};
    fp8x2_t pair3 = packed3.to_fp8x2();
    y[6]          = pair3[0];
    y[7]          = pair3[1];
}
|
||||
|
||||
// Compile-time flag, always true for this functor.
// NOTE(review): presumably queried by callers to detect the 8-wide packed
// overloads above; confirm against the users of this trait.
static constexpr bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user