[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603)

* chore: split block scale example instances in more separate files to speed up compile times

* wip: fp4 scaffolding for abquant

* feat: add fp4 decoding-while-loading to abquant pipeline

* feat: add support for fp4 CPU verification in abquant

* chore: add time tracking to reference calculation

* feat: add a4w4 test for blockscale gemm

* feat: optimize reference calculation by preconverting values to AccType

* feat: add fp4 to fp8 look-up table

* fix: reference to wrong ComputeDataType field in QuantProblem

* feat: type utilities for determining MFMA compute types

* feat: packed fp4 for abquant weight preshuffle

* feat: add separate tests for a4w4 base case, padding and preshuffleB

* fix: fp4 conversion on gfx950 attempting to use non-supported method

* fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size

* chore: add fp4 preshuffleb mode to block scale example

* chore: sanity check for packed types being 1 byte

* chore: clarify tensor dimension indices with constants

* chore: replace traits check with specialized check for packed types

* style: some minor refactoring and cleanup

* fix: correct conversion table for FNUZ fp8

* chore: add fp4 instances to main abquant instances again

* chore: use same initialization branch for int4 and fp4

* chore: add missing initialization for fp4 in block scale gemm example

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Erwin Terpstra
2026-01-30 12:40:50 +01:00
committed by GitHub
parent 565fea2645
commit 6a6177a246
28 changed files with 642 additions and 175 deletions

View File

@@ -6,6 +6,7 @@
#include <cmath>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/mxfp_convert.hpp"
#if defined(__gfx950__)
@@ -23,6 +24,12 @@ using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
// Two-element fp8 vector type. The lane type is fp8_raw_t when
// CK_TILE_USE_CUSTOM_DATA_TYPE is enabled and fp8_t otherwise.
#if CK_TILE_USE_CUSTOM_DATA_TYPE
using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2)));
#else
using fp8x2_t = fp8_t __attribute__((ext_vector_type(2)));
#endif
// Helpers: constexpr-safe access to elements of ext_vector_type(2)
// Some compilers don't allow operator[] in constant expressions for vector types.
// We use bit_cast to a trivially copyable representation to extract lanes.
@@ -98,6 +105,8 @@ struct pk_float4_e2m1_t
// Conversion helpers. Each takes an optional `scale` that defaults to 1.f; the *x2 variants
// return a two-lane vector type, the others a single value.
CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const;
// fp8 targets; the definitions live out-of-line further down in this header.
CK_TILE_HOST_DEVICE constexpr fp8_t to_fp8(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr fp8x2_t to_fp8x2(float scale = 1.f) const;
// Implicit conversions call the matching to_*() helper without an argument, i.e. with that
// helper's default scale.
CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); }
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
@@ -105,6 +114,8 @@ struct pk_float4_e2m1_t
// Implicit conversion operators. Each one forwards to the to_*() helper of the same target
// type, called without an argument (so the helper's default scale of 1.f applies).
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); }
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); }
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); }
CK_TILE_HOST_DEVICE constexpr operator fp8_t() const { return to_fp8(); }
CK_TILE_HOST_DEVICE constexpr operator fp8x2_t() const { return to_fp8x2(); }
template <index_t I>
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t unpack(number<I>) const
@@ -145,6 +156,49 @@ struct pk_float4_e2m1_t
bit_cast<fp16_t>(static_cast<uint16_t>(0xC400)), // -4
bit_cast<fp16_t>(static_cast<uint16_t>(0xC600)) // -6
};
#if CK_TILE_USE_OCP_FP8
// FP8 E4M3 (OCP) representation
// 16 entries, one per 4-bit fp4 code. Entries 0-7 hold the non-negative values; entries 9-15
// hold the same bit patterns with the top bit (0x80) set. Entry 8 (-0) is stored as 0x00,
// i.e. the same byte as +0, so the sign of zero is not preserved by this table.
static constexpr fp8_t e2m1_to_fp8_table[16] = {
fp8_t(static_cast<uint8_t>(0x00)), // 0
fp8_t(static_cast<uint8_t>(0x30)), // 0.5
fp8_t(static_cast<uint8_t>(0x38)), // 1
fp8_t(static_cast<uint8_t>(0x3C)), // 1.5
fp8_t(static_cast<uint8_t>(0x40)), // 2
fp8_t(static_cast<uint8_t>(0x44)), // 3
fp8_t(static_cast<uint8_t>(0x48)), // 4
fp8_t(static_cast<uint8_t>(0x4C)), // 6
fp8_t(static_cast<uint8_t>(0x00)), // -0
fp8_t(static_cast<uint8_t>(0xB0)), // -0.5
fp8_t(static_cast<uint8_t>(0xB8)), // -1
fp8_t(static_cast<uint8_t>(0xBC)), // -1.5
fp8_t(static_cast<uint8_t>(0xC0)), // -2
fp8_t(static_cast<uint8_t>(0xC4)), // -3
fp8_t(static_cast<uint8_t>(0xC8)), // -4
fp8_t(static_cast<uint8_t>(0xCC)) // -6
};
#else // CK_TILE_USE_FNUZ_FP8
// FP8 E4M3 FNUZ representation
// 16 entries, one per 4-bit fp4 code. Entries 0-7 hold the non-negative values; entries 9-15
// hold the same bit patterns with the top bit (0x80) set, so every negative entry must equal
// its positive counterpart | 0x80.
// Entry 8 (-0) is stored as 0x00, the same byte as +0.
// NOTE(review): presumably because 0x80 does not encode -0 in the FNUZ format - confirm.
static constexpr fp8_t e2m1_to_fp8_table[16] = {
    fp8_t(static_cast<uint8_t>(0x00)), //  0
    fp8_t(static_cast<uint8_t>(0x38)), //  0.5
    fp8_t(static_cast<uint8_t>(0x40)), //  1
    fp8_t(static_cast<uint8_t>(0x44)), //  1.5
    fp8_t(static_cast<uint8_t>(0x48)), //  2
    fp8_t(static_cast<uint8_t>(0x4C)), //  3
    fp8_t(static_cast<uint8_t>(0x50)), //  4
    fp8_t(static_cast<uint8_t>(0x54)), //  6
    fp8_t(static_cast<uint8_t>(0x00)), // -0
    fp8_t(static_cast<uint8_t>(0xB8)), // -0.5
    fp8_t(static_cast<uint8_t>(0xC0)), // -1
    fp8_t(static_cast<uint8_t>(0xC4)), // -1.5
    fp8_t(static_cast<uint8_t>(0xC8)), // -2 (was 0xC4, which duplicated the -1.5 entry)
    fp8_t(static_cast<uint8_t>(0xCC)), // -3
    fp8_t(static_cast<uint8_t>(0xD0)), // -4
    fp8_t(static_cast<uint8_t>(0xD4))  // -6
};
#endif
#endif
};
@@ -408,6 +462,27 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale))};
#endif
}
CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const
{
    // NOTE: there is no dedicated fp4 -> fp8 conversion instruction. Lane 0 is therefore run
    // through convert_to_float (with `scale` applied there) and the result is passed to
    // type_convert<fp8_t>. It is unclear whether an fp4 -> fp16 -> fp8 route would be faster
    // than this naive one. The intrinsic-style variant is kept here for reference only:
    //   #if CK_TILE_FP4_CVT_DEVICE
    //       return impl::_from_f4<fp8_t>(data, scale);
    //   #endif
    const auto lane0_scaled = convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale);
    return fp8_t{type_convert<fp8_t>(lane0_scaled)};
}
CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const
{
    // NOTE: there is no dedicated fp4 -> fp8 conversion instruction. Each lane is therefore run
    // through convert_to_float (with `scale` applied there) and the result is passed to
    // type_convert<fp8_t>. It is unclear whether an fp4 -> fp16 -> fp8 route would be faster
    // than this naive one. The intrinsic-style variant is kept here for reference only:
    //   #if CK_TILE_FP4_CVT_DEVICE
    //       return impl::_from_f4<fp8x2_t>(data, scale);
    //   #endif
    const auto lane0_scaled = convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale);
    const auto lane1_scaled = convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale);
    return fp8x2_t{type_convert<fp8_t>(lane0_scaled), type_convert<fp8_t>(lane1_scaled)};
}
#else
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
{
@@ -415,7 +490,8 @@ CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
}
// Table-based conversion of both packed lanes to float: each lane's code (from _unpack) indexes
// e2m1_to_fp32_table and the looked-up value is multiplied by `scale`.
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
{
    // A stale single-line copy of this statement with an unbalanced bracket
    // (`_unpack(number<1>{}]`) preceded it; only the well-formed statement is kept.
    return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale,
                    e2m1_to_fp32_table[_unpack(number<1>{})] * scale};
}
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
{
@@ -428,6 +504,16 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[_unpack(number<1>{})]) *
scale)};
}
// Table-based conversion of lane 0 to fp8: look up the fp8 encoding of the fp4 code, widen it
// to float, apply `scale`, and convert the scaled value back to fp8.
CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const
{
    // The explicit type_convert<fp8_t> is required: returning the float product directly would
    // rely on an implicit float -> fp8_t conversion instead of the fp8 encoding step that
    // to_fp8x2 performs for each of its lanes.
    return type_convert<fp8_t>(type_convert<float>(e2m1_to_fp8_table[_unpack(number<0>{})]) *
                               scale);
}
// Table-based conversion of both packed lanes to fp8: each lane's table entry is widened to
// float, multiplied by `scale`, and converted back to fp8.
CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const
{
    const float lane0_scaled = type_convert<float>(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale;
    const float lane1_scaled = type_convert<float>(e2m1_to_fp8_table[_unpack(number<1>{})]) * scale;
    return fp8x2_t{type_convert<fp8_t>(lane0_scaled), type_convert<fp8_t>(lane1_scaled)};
}
#endif
} // namespace ck_tile

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
@@ -23,6 +24,11 @@ struct pk_int4_t
type data;
// Default constructor: value-initializes the packed storage (`data{type{}}`).
CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {}
// Constructs from a raw storage value; not explicit, so `type` converts implicitly to pk_int4_t.
CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {}
// NOTE: added for interface compatibility with pk_fp4_t
// Other data types could be added for greater similarity
// Unlike the pk_fp4_t counterpart, this overload takes no scale argument.
CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2() const;
// Implicit conversion to a two-lane float vector; forwards to to_fp32x2().
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
};
// limits
@@ -186,4 +192,9 @@ CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x)
return res;
}
// Member-function wrapper around the free function pk_int4_t_to_fp32x2_t, so pk_int4_t offers
// the same to_fp32x2() spelling as pk_fp4_t.
// NOTE(review): this member is declared constexpr; confirm pk_int4_t_to_fp32x2_t is constexpr
// as well, otherwise the call cannot be used in a constant expression.
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_int4_t::to_fp32x2() const
{
return pk_int4_t_to_fp32x2_t(*this);
}
} // namespace ck_tile

View File

@@ -11,6 +11,7 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/numeric/e8m0.hpp"
#include "ck_tile/core/utility/type_traits.hpp"