From 5c814645680d01466e41ade2d467c11277df781c Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Sun, 14 Dec 2025 19:25:47 +0100 Subject: [PATCH] CK Tile: Enable padding blockscale example (#3417) * Fix host code padding * restructure the ref code * clean up * Fix compilation error --------- Co-authored-by: ThomasNing [ROCm/composable_kernel commit: 21f06aa47ded64b9a07d81bf4b743c21462178db] --- .../38_block_scale_gemm/gemm_utils.hpp | 2 +- .../run_gemm_quant_example.inc | 19 ++-- .../ck_tile/host/reference/reference_gemm.hpp | 99 ++++++++++--------- 3 files changed, 58 insertions(+), 62 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index aabbfff3bd..7a4760e1da 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -91,7 +91,7 @@ struct GemmConfigBase { static constexpr bool kPadM = false; static constexpr bool kPadN = false; - static constexpr bool kPadK = false; + static constexpr bool kPadK = true; static constexpr bool PermuteA = false; static constexpr bool PermuteB = false; diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index fa5e1f12e3..a0e875448d 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -391,25 +391,18 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t K = arg_parser.get_int("k"); - if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::BQuantGrouped) - { - if(K % QuantGroupSize::kK != 0) - { - throw std::runtime_error( - "K must be aligned with QuantGroupSize for AQuantGrouped/BQuantGrouped mode"); - } - } ck_tile::index_t AQK, BQK, BQN = 0; if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { - AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize - BQK = 0; // No B quantization + AQK = ck_tile::integer_divide_ceil( + K, QuantGroupSize::kK); // Group quantization: AQK = K / GroupSize + BQK = 0; // No B quantization } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { - AQK = 0; // No A quantization - BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize + AQK = 0; // No A quantization + BQK = ck_tile::integer_divide_ceil( + K, QuantGroupSize::kK); // Group quantization: BQK = K / GroupSize BQN = ck_tile::integer_divide_ceil(N, QuantGroupSize::kN); } else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant || diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 0aa296b8d9..8b0e3028ae 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -34,77 +34,80 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor& a_m_k, const std::size_t K = a_m_k.get_length(1); auto f_mn = [&](auto m, auto n) { - AccDataType v_acc = 0, v_block_acc = 0; + AccDataType v_acc = 0; - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v); - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v); - static_assert(std::is_same_v); - static_assert(std::is_same_v || - std::is_same_v); - for(std::size_t k = 0; k < K; ++k) - { - AccDataType v_a; - AccDataType v_b; + constexpr std::size_t kGroupK = QuantGroupSize::kK; + + // ---- A loader: dequant A(m,k) into AccDataType ---- + auto load_a = [&](std::size_t k) -> AccDataType { if constexpr(std::is_same_v) { const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); - if(k % 2 == 1) - v_a = fp32_val.hi; - else - v_a = fp32_val.lo; + return (k & 1) ? fp32_val.hi : fp32_val.lo; } else { - v_a = ck_tile::type_convert(a_element_op(a_m_k(m, k))); + return ck_tile::type_convert(a_element_op(a_m_k(m, k))); } + }; + + // ---- B loader: dequant B(k,n) into AccDataType ---- + auto load_b = [&](std::size_t k) -> AccDataType { if constexpr(std::is_same_v) { const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); - if(k % 2 == 1) - v_b = fp32_val.hi; - else - v_b = fp32_val.lo; + return (k & 1) ? fp32_val.hi : fp32_val.lo; } else if constexpr(std::is_same_v) { - v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n))); + return fp8_to_float_raw(b_element_op(b_k_n(k, n))); } else { - v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); + return ck_tile::type_convert(b_element_op(b_k_n(k, n))); } - v_block_acc += v_a * v_b; + }; - // Apply group dequant scale - if((k + 1) % QuantGroupSize::kK == 0) + // ---- scale loader for a given K-group index ---- + auto load_scale = [&](ck_tile::index_t k_group) -> float { + const ck_tile::index_t outer_dim = aquant ? (m / QuantGroupSize::kM) : k_group; + const ck_tile::index_t inner_dim = aquant ? k_group : (n / QuantGroupSize::kN); + + if constexpr(std::is_same_v) { - float scale = 0.f; - index_t outer_dim = (aquant) ? (m / QuantGroupSize::kM) : (k / QuantGroupSize::kK); - index_t inner_dim = (aquant) ? (k / QuantGroupSize::kK) : (n / QuantGroupSize::kN); - if constexpr(std::is_same_v) - { - scale = q(outer_dim, inner_dim); - } - else if constexpr(std::is_same_v) - { - scale = fp8_to_float_raw(q(outer_dim, inner_dim)); - } - else if constexpr(std::is_same_v) - { - scale = bf8_to_float_raw(q(outer_dim, inner_dim)); - } - else - { - static_assert(false, "Unexpected Q datatype."); - } - v_block_acc *= scale; - v_acc += v_block_acc; - v_block_acc = 0; + return q(outer_dim, inner_dim); } + else if constexpr(std::is_same_v) + { + return fp8_to_float_raw(q(outer_dim, inner_dim)); + } + else // QDataType == bf8_t by static_assert above + { + return bf8_to_float_raw(q(outer_dim, inner_dim)); + } + }; + + // ---- Loop over K by groups (full and tail) ---- + for(std::size_t k_begin = 0; k_begin < K; k_begin += kGroupK) + { + const std::size_t k_end = std::min(k_begin + kGroupK, K); + + AccDataType v_block_acc = 0; + + // unscaled accumulation within this K-group + for(std::size_t k = k_begin; k < k_end; ++k) + { + const AccDataType v_a = load_a(k); + const AccDataType v_b = load_b(k); + v_block_acc += v_a * v_b; + } + + const ck_tile::index_t k_group = static_cast(k_begin / kGroupK); + const float scale = load_scale(k_group); + + v_acc += v_block_acc * scale; } c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc));