From b6a66c19e8f3e5f01a7d79afe476d7626e7e0f06 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 30 Jan 2026 11:27:27 +0000 Subject: [PATCH] Finalize cleanup --- .../run_gemm_quant_example.inc | 14 ++-- .../ck_tile/host/reference/reference_gemm.hpp | 14 ++-- .../block_universal_gemm_as_bs_bquant_cr.hpp | 71 ++++++++++++------- .../gemm_mx_pipeline_ag_bg_cr_policy.hpp | 4 ++ .../test_gemm_quant_fixtures.hpp | 14 ++-- 5 files changed, 71 insertions(+), 46 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 37c69db30e..da8e4541dc 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -825,13 +825,13 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { if constexpr(std::is_same_v) - ck_tile::reference_mxfp4gemm_quant( + ck_tile::reference_mx_gemm_bquant( a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); else ck_tile::reference_gemm_quant -CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor& a_m_k, - const HostTensor& q, - const HostTensor& b_k_n, - HostTensor& c_m_n, - const AElementOp& a_element_op = {}, - const BElementOp& b_element_op = {}, - const ACCElementOp& acc_element_op = {}) +CK_TILE_HOST void reference_mx_gemm_bquant(const HostTensor& a_m_k, + const HostTensor& q, + const HostTensor& b_k_n, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) { const std::size_t M = a_m_k.get_length(0); const std::size_t N = b_k_n.get_length(1); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 51bc35efe9..277a249614 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -216,6 +216,7 @@ struct BQuantBlockUniversalGemmAsBsCr using BTypeTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); BTypeTile b_warp_tile_lds_; + // Load from LDS (assumption is that the scale will be applied in the block gemm) template = {}, bool_constant = {}) { - load_int4_tile( - a_warp_tile_, a_block_window); - load_int4_tile( - b_warp_tile_lds_, b_block_window); + // Load tile from LDS - // Apply scale - using BDataTypeRaw = typename std:: - conditional, pk_fp4_t::type, BDataType>::type; + // Do not use load_int4_tile here because it will have support to cast from fp4 to + // compute type, while here we want to only load from LDS and then apply the scale + // and cast later + if constexpr(ALoadTranspose) + { + a_warp_tile_ = load_tile_transpose(a_block_window); + } + else + { + load_tile(a_warp_tile_, a_block_window); + } - constexpr auto warp_size = get_warp_size(); + if constexpr(BLoadTranspose) + { + b_warp_tile_lds_ = load_tile_transpose(b_block_window); + } + else + { + load_tile(b_warp_tile_lds_, b_block_window); + } + + // Apply scale and cast + using BDataTypeRaw = + std::conditional_t, pk_fp4_t::type, BDataType>; + + constexpr index_t warp_size = get_warp_size(); constexpr index_t nelements = WarpGemm::kK * WarpGemm::kN / warp_size; constexpr index_t thread_buffer_size = nelements / UnaryOpSize_; const element_wise::DequantPack8 elementwise_op{}; @@ -262,6 +282,22 @@ struct BQuantBlockUniversalGemmAsBsCr static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + // B scale register offset + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN)) + return (nIter * NWarp * WarpGemm::kN) / GemmTraits::QuantGroupSize::kN * + Traits::KQPerBlock + + kQScale; + else + { + return nIter * Traits::KQPerBlock + kQScale; + } + }(); + + // Get B scale from thread buffer + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float b_scale_f = float(scale_reg); + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; // Thread buffers @@ -275,27 +311,12 @@ struct BQuantBlockUniversalGemmAsBsCr BWarpThreadBuffer b_warp_thread_buffer; BLDSThreadBuffer b_lds_thread_buffer; - // BQuant register offset - constexpr index_t reg_offset = [&]() { - if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN)) - return (nIter * NWarp * WarpGemm::kN) / - GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock + - kQScale; - else - { - return nIter * Traits::KQPerBlock + kQScale; - } - }(); - // Load thread buffer from tile (LDS type) b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data( merge_sequences(sequence{}, b_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - // Apply scale to thread buffer and cast - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - float b_scale_f = float(scale_reg); - + // Apply scale to B thread buffer and cast static_for<0, thread_buffer_size, 1>{}([&](auto i) { elementwise_op( b_warp_thread_buffer.template get_as()(i), @@ -303,7 +324,7 @@ struct BQuantBlockUniversalGemmAsBsCr b_scale_f); }); - // Store thread buffer to tile (MMA type) + // Store B thread buffer to tile (MMA type) b_warp_tile_.set_y_sliced_thread_data( merge_sequences(sequence{}, b_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, b_warp_y_lengths), diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp index 64fff27fa1..d77a2d1da6 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp @@ -70,6 +70,10 @@ struct GemmMxPipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution() { + // If we apply scale before writing to LDS, we need a tile distribution for + // BQuant consistent with global memory reading of matrix B, while + // if we apply scale after reading from LDS, we need a tile distribution for + // BQuant consistent with the MMA instructions layout if constexpr(Problem::BCastPolicy == CastPolicy::AfterLDSRead) { using BQLayout = remove_cvref_t; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 701bd2410d..27905bb292 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -787,13 +787,13 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase) - ck_tile::reference_mxfp4gemm_quant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); + ck_tile::reference_mx_gemm_bquant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); else ck_tile::reference_gemm_quant