mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603)
* chore: split block scale example instances in more separate files to speed up compile times
* wip: fp4 scaffolding for abquant
* feat: add fp4 decoding-while-loading to abquant pipeline
* feat: add support for fp4 CPU verification in abquant
* chore: add time tracking to reference calculation
* feat: add a4w4 test for blockscale gemm
* feat: optimize reference calculation by preconverting values to AccType
* feat: add fp4 to fp8 look-up table
* fix: reference to wrong ComputeDataType field in QuantProblem
* feat: type utilities for determining MFMA compute types
* feat: packed fp4 for abquant weight preshuffle
* feat: add separate tests for a4w4 base case, padding and preshuffleB
* fix: fp4 conversion on gfx950 attempting to use non-supported method
* fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size
* chore: add fp4 preshuffleb mode to block scale example
* chore: sanity check for packed types being 1 byte
* chore: clarify tensor dimension indices with constants
* chore: replace traits check with specialized check for packed types
* style: some minor refactoring and cleanup
* fix: correct conversion table for FNUZ fp8
* chore: add fp4 instances to main abquant instances again
* chore: use same initialization branch for int4 and fp4
* chore: add missing initialization for fp4 in block scale gemm example
---------
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
[ROCm/composable_kernel commit: 6a6177a246]
This commit is contained in:
@@ -137,47 +137,55 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor<ADataType>& a_m_k,
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
constexpr auto A_TENSOR_M_DIM = 0;
|
||||
constexpr auto A_TENSOR_K_DIM = 1;
|
||||
constexpr auto B_TENSOR_K_DIM = 0;
|
||||
constexpr auto B_TENSOR_N_DIM = 1;
|
||||
|
||||
const std::size_t M = a_m_k.get_length(A_TENSOR_M_DIM);
|
||||
const std::size_t N = b_k_n.get_length(B_TENSOR_N_DIM);
|
||||
const std::size_t K = a_m_k.get_length(A_TENSOR_K_DIM);
|
||||
|
||||
// Pre-convert A/B tensors to AccData type
|
||||
// This prevents doing slow reconversions for each row/column
|
||||
HostTensor<AccDataType> a_acc(a_m_k.mDesc);
|
||||
HostTensor<AccDataType> b_acc(b_k_n.mDesc);
|
||||
|
||||
a_acc.ForEach([&](auto& self, auto index) {
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
const ADataType pk_val = a_element_op(a_m_k(index));
|
||||
const fp32x2_t fp32_val = pk_val.to_fp32x2();
|
||||
self(index) = (index[A_TENSOR_K_DIM] & 1) ? fp32_val.hi : fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
self(index) = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(index)));
|
||||
}
|
||||
});
|
||||
|
||||
b_acc.ForEach([&](auto& self, auto index) {
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t> || std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
const BDataType pk_val = b_element_op(b_k_n(index));
|
||||
const fp32x2_t fp32_val = pk_val.to_fp32x2();
|
||||
self(index) = (index[B_TENSOR_K_DIM] & 1) ? fp32_val.hi : fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
self(index) = fp8_to_float_raw(b_element_op(b_k_n(index)));
|
||||
}
|
||||
else
|
||||
{
|
||||
self(index) = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(index)));
|
||||
}
|
||||
});
|
||||
|
||||
auto f_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
constexpr std::size_t kGroupK = BQuantGroupSize::kK;
|
||||
|
||||
// ---- A loader: dequant A(m,k) into AccDataType ----
|
||||
auto load_a = [&](std::size_t k) -> AccDataType {
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
return (k & 1) ? fp32_val.hi : fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
}
|
||||
};
|
||||
|
||||
// ---- B loader: dequant B(k,n) into AccDataType ----
|
||||
auto load_b = [&](std::size_t k) -> AccDataType {
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
return (k & 1) ? fp32_val.hi : fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
return fp8_to_float_raw(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
};
|
||||
|
||||
// ---- a scale loader for a given K-group index ----
|
||||
auto load_scale_a = [&](ck_tile::index_t k_group) -> float {
|
||||
const ck_tile::index_t outer_dim = m / AQuantGroupSize::kM;
|
||||
@@ -224,8 +232,8 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor<ADataType>& a_m_k,
|
||||
// unscaled accumulation within this K-group
|
||||
for(std::size_t k = k_begin; k < k_end; ++k)
|
||||
{
|
||||
const AccDataType v_a = load_a(k);
|
||||
const AccDataType v_b = load_b(k);
|
||||
const AccDataType v_a = a_acc(m, k);
|
||||
const AccDataType v_b = b_acc(k, n);
|
||||
v_block_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user