CK Tile: Enable padding blockscale example (#3417)

* Fix host code padding

* restructure the ref code

* clean up

* Fix compilation error

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Enrico Degregori
2025-12-14 19:25:47 +01:00
committed by GitHub
parent 6219b12730
commit 21f06aa47d
3 changed files with 58 additions and 62 deletions

View File

@@ -34,77 +34,80 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
AccDataType v_acc = 0, v_block_acc = 0;
AccDataType v_acc = 0;
static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
std::is_same_v<ADataType, bf8_t>);
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
std::is_same_v<BDataType, pk_int4_t>);
static_assert(std::is_same_v<AccDataType, float>);
static_assert(std::is_same_v<CDataType, float> ||
std::is_same_v<CDataType, ck_tile::half_t>);
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_a;
AccDataType v_b;
constexpr std::size_t kGroupK = QuantGroupSize::kK;
// ---- A loader: dequant A(m,k) into AccDataType ----
auto load_a = [&](std::size_t k) -> AccDataType {
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
return (k & 1) ? fp32_val.hi : fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
return ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
}
};
// ---- B loader: dequant B(k,n) into AccDataType ----
auto load_b = [&](std::size_t k) -> AccDataType {
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
return (k & 1) ? fp32_val.hi : fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, fp8_t>)
{
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
return fp8_to_float_raw(b_element_op(b_k_n(k, n)));
}
else
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
return ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
}
v_block_acc += v_a * v_b;
};
// Apply group dequant scale
if((k + 1) % QuantGroupSize::kK == 0)
// ---- scale loader for a given K-group index ----
auto load_scale = [&](ck_tile::index_t k_group) -> float {
const ck_tile::index_t outer_dim = aquant ? (m / QuantGroupSize::kM) : k_group;
const ck_tile::index_t inner_dim = aquant ? k_group : (n / QuantGroupSize::kN);
if constexpr(std::is_same_v<QDataType, float>)
{
float scale = 0.f;
index_t outer_dim = (aquant) ? (m / QuantGroupSize::kM) : (k / QuantGroupSize::kK);
index_t inner_dim = (aquant) ? (k / QuantGroupSize::kK) : (n / QuantGroupSize::kN);
if constexpr(std::is_same_v<QDataType, float>)
{
scale = q(outer_dim, inner_dim);
}
else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
{
scale = fp8_to_float_raw(q(outer_dim, inner_dim));
}
else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
{
scale = bf8_to_float_raw(q(outer_dim, inner_dim));
}
else
{
static_assert(false, "Unexpected Q datatype.");
}
v_block_acc *= scale;
v_acc += v_block_acc;
v_block_acc = 0;
return q(outer_dim, inner_dim);
}
else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
{
return fp8_to_float_raw(q(outer_dim, inner_dim));
}
else // QDataType == bf8_t by static_assert above
{
return bf8_to_float_raw(q(outer_dim, inner_dim));
}
};
// ---- Loop over K by groups (full and tail) ----
for(std::size_t k_begin = 0; k_begin < K; k_begin += kGroupK)
{
const std::size_t k_end = std::min<std::size_t>(k_begin + kGroupK, K);
AccDataType v_block_acc = 0;
// unscaled accumulation within this K-group
for(std::size_t k = k_begin; k < k_end; ++k)
{
const AccDataType v_a = load_a(k);
const AccDataType v_b = load_b(k);
v_block_acc += v_a * v_b;
}
const ck_tile::index_t k_group = static_cast<ck_tile::index_t>(k_begin / kGroupK);
const float scale = load_scale(k_group);
v_acc += v_block_acc * scale;
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));