Support A/B Quantization in Blockscale GEMM (#3343)

* Support A/B Quantization in Blockscale GEMM

* Support A/B Quantization in Blockscale GEMM

* Support A/B Quantization in Blockscale GEMM

* Support A/B Quantization in Blockscale GEMM

* Support A/B Quantization in Blockscale GEMM

* Implement review suggested changes

* Implement review suggested changes

* Sync with develop

* fix pre-commit error

* Add unit tests for blockscale AB-Quantization

* fix pre-commit error

* fix pre-commit error

* fix compile error

* fix compile error

* fix clang-format

* fix clang-format

* fix enumeration values not handled in switch

* rebase file

* Add missing enums to data_type_sizeof (#3430)

Fixes broken build on gfx942. This was some test code that got merged at the same time.

* [CK_BUILDER] CK Tile header installation for builder, algorithm concept improvements (#3419)

* Added install of CK_Tile headers when using CK_EXPERIMENTAL_BUILDER. MIOpen needs this since the builder uses features from CK Tile and the CK Tile install is excluded when doing a narrow build for MIOpen
* Changed algorithm concept type checks to be concepts instead of constexpr bool functions. This improves compiler error messages when using these concepts in static_asserts

---------

Co-authored-by: Daryl Hawkins <DarylHawkins@amd.com>

* Add build trace diagnostics to CI. (#3432)

* generate and visualize build traces for all archs

* generate build traces in all cases

* fix jenkins logic

* fix typo

* use more threads for parsing dependency map

* add script to parse ninja traces and issue warnings

* fix python script syntax and header

* fix python syntax one more time

* fix python syntax

* Support A/B Quantization in Blockscale GEMM

* Implement review suggested changes

* Sync with develop

* Add unit tests for blockscale AB-Quantization

* fix enumeration values not handled in switch

* rebase file

* rebase file

---------

Co-authored-by: John Shumway <jshumway@amd.com>
Co-authored-by: DarylHawkinsAMD <Daryl.Hawkins@amd.com>
Co-authored-by: Daryl Hawkins <DarylHawkins@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
kensclin
2025-12-17 23:13:47 +08:00
committed by GitHub
parent 292df2719f
commit 0500fcc017
30 changed files with 2318 additions and 353 deletions

View File

@@ -117,6 +117,132 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
std::cout << std::endl;
}
template <typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename AQuantGroupSize,
typename BQuantGroupSize,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_abquant(const HostTensor<ADataType>& a_m_k,
const HostTensor<AQDataType>& a_q,
const HostTensor<BDataType>& b_k_n,
const HostTensor<BQDataType>& b_q,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
AccDataType v_acc = 0, v_block_acc = 0;
static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
std::is_same_v<ADataType, bf8_t>);
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
std::is_same_v<BDataType, pk_int4_t>);
static_assert(std::is_same_v<AccDataType, float>);
static_assert(std::is_same_v<CDataType, float> ||
std::is_same_v<CDataType, ck_tile::half_t>);
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, fp8_t>)
{
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
}
else
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
}
v_block_acc += v_a * v_b;
// Apply group dequant scale
if((k + 1) % BQuantGroupSize::kK == 0)
{
float a_scale = 0.f;
float b_scale = 0.f;
// A scale
index_t outer_dim = m / AQuantGroupSize::kM;
index_t inner_dim = k / AQuantGroupSize::kK;
if constexpr(std::is_same_v<AQDataType, float>)
{
a_scale = a_q(outer_dim, inner_dim);
}
else if constexpr(std::is_same_v<AQDataType, ck_tile::fp8_t>)
{
a_scale = fp8_to_float_raw(a_q(outer_dim, inner_dim));
}
else if constexpr(std::is_same_v<AQDataType, ck_tile::bf8_t>)
{
a_scale = bf8_to_float_raw(a_q(outer_dim, inner_dim));
}
else
{
static_assert(false, "Unexpected Q datatype.");
}
// B scale
outer_dim = k / BQuantGroupSize::kK;
inner_dim = n / BQuantGroupSize::kN;
if constexpr(std::is_same_v<BQDataType, float>)
{
b_scale = b_q(outer_dim, inner_dim);
}
else if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
{
b_scale = fp8_to_float_raw(b_q(outer_dim, inner_dim));
}
else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
{
b_scale = bf8_to_float_raw(b_q(outer_dim, inner_dim));
}
else
{
static_assert(false, "Unexpected Q datatype.");
}
v_block_acc = v_block_acc * a_scale * b_scale;
v_acc += v_block_acc;
v_block_acc = 0;
}
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
};
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}
template <typename ADataType,
typename AQDataType,
typename BDataType,