mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603)
* chore: split block scale example instances in more separate files to speed up compile times * wip: fp4 scaffolding for abquant * feat: add fp4 decoding-while-loading to abquant pipeline * feat: add support for fp4 CPU verification in abquant * chore: add time tracking to reference calculation * feat: add a4w4 test for blockscale gemm * feat: optimize reference calculation by preconverting values to AccType * feat: add fp4 to fp8 look-up table * fix: reference to wrong ComputeDataType field in QuantProblem * feat: type utilities for determining MFMA compute types * feat: packed fp4 for abquant weight preshuffle * feat: add separate tests for a4w4 base case, padding and preshuffleB * fix: fp4 conversion on gfx950 attempting to use non-supported method * fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size * chore: add fp4 preshuffleb mode to block scale example * chore: sanity check for packed types being 1 byte * chore: clarify tensor dimension indices with constants * chore: replace traits check with specialized check for packed types * style: some minor refactoring and cleanup * fix: correct conversion table for FNUZ fp8 * chore: add fp4 instances to main abquant instances again * chore: use same initialization branch for int4 and fp4 * chore: add missing initialization for fp4 in block scale gemm example --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -101,9 +101,11 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
// 4. i4, bf8, (fp8/fp32) -> f32
|
||||
static_assert(
|
||||
(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t> ||
|
||||
std::is_same_v<ADataType, ck_tile::pk_int4_t>) &&
|
||||
std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>) &&
|
||||
(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t>) &&
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>) &&
|
||||
(std::is_same_v<AQDataType, float> || std::is_same_v<AQDataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<AQDataType, ck_tile::bf8_t>) &&
|
||||
(std::is_same_v<BQDataType, float> || std::is_same_v<BQDataType, ck_tile::fp8_t> ||
|
||||
@@ -189,7 +191,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
typename BFlatBlockTensor,
|
||||
typename AQBlockTensor,
|
||||
typename BQBlockTensor,
|
||||
typename ABlockWindow>
|
||||
typename ABlockWindow,
|
||||
index_t UnaryOpSize = 8>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
ABlockTensor& a_warp_tensor,
|
||||
BFlatBlockTensor& b_warp_tensor,
|
||||
@@ -249,8 +252,10 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
|
||||
|
||||
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize>(
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
a_warp_windows(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
// barrier
|
||||
// Could be deleted
|
||||
|
||||
@@ -108,9 +108,11 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
// 4. i4, bf8, (fp8/fp32) -> f32
|
||||
static_assert(
|
||||
(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t> ||
|
||||
std::is_same_v<ADataType, ck_tile::pk_int4_t>) &&
|
||||
std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>) &&
|
||||
(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t>) &&
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>) &&
|
||||
(std::is_same_v<AQDataType, float> || std::is_same_v<AQDataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<AQDataType, ck_tile::bf8_t>) &&
|
||||
(std::is_same_v<BQDataType, float> || std::is_same_v<BQDataType, ck_tile::fp8_t> ||
|
||||
@@ -135,12 +137,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
// BDataType gets converted from PkInt4 during loading
|
||||
using OverrideBDataType = std::conditional_t<
|
||||
std::is_same_v<BDataType, pk_int4_t> &&
|
||||
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
|
||||
ADataType,
|
||||
BDataType>;
|
||||
// A/B DataType get converted from PkInt4/PkFp4 during loading
|
||||
using OverrideADataType = ComputeDataType;
|
||||
using OverrideBDataType = ComputeDataType;
|
||||
|
||||
using Base = BlockGemmQuantBase;
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
@@ -268,9 +267,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
|
||||
// If A/B datatype were pkint4/pkfp4 it would be converted prior to storing in LDS
|
||||
load_int4_tile<OverrideADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
|
||||
a_warp_tile_, a_block_window);
|
||||
// If B datatype were pkint4 it would be converted prior to storing in LDS
|
||||
load_int4_tile<OverrideBDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
|
||||
b_warp_tile_, b_block_window);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user