From 9f069d6e356e25534f4ad2a77ae40801c76cbd2d Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Mon, 3 Nov 2025 00:49:20 +0000 Subject: [PATCH] [CK_TILE] B matrix 2D block scale gemm (#3074) * Refactor quant group size to be configurable for M/N/K, not just K * add some asserts for configurations not implemented * start setting of group size for N dimension * enable 2d for reference quant gemm * WIP: trying to figure out tile dstr and/or indexing for scale matrix * WIP * Fix handling of n dim blocks in tile windows etc * remove commented code and enable all tests again * fix formatting * Add more specialized tile distributions * Enable NWarps replication for bquant tile dstr * fix formatting * fix format * Fix some issues from the merge * fix formatting * one more fix to tile dstr, and revert debug initialization * Remove commented code Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * simplify conditions that are needed for tile distributions * only enable the working group sizes in tests * fix formatting * Update tile distribution for 2D bquant * add some documentation and 2d block scale example * fix formatting * Add in Changlog and restructure the quant 2d example * fix CMake * support the change for blockscale 2d * fix the test file --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Cong Ma Co-authored-by: ThomasNing [ROCm/composable_kernel commit: 16e85cf179fd8e98f56d664642d37a6775d7bc4d] --- CHANGELOG.md | 1 + .../38_block_scale_gemm/gemm_quant_basic.cpp | 265 ++++++++---------- .../38_block_scale_gemm/gemm_utils.hpp | 29 +- .../run_gemm_quant_example.inc | 16 +- .../ck_tile/host/reference/reference_gemm.hpp | 9 +- ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 16 +- .../block_universal_gemm_as_aquant_bs_cr.hpp | 22 +- .../block_universal_gemm_as_bs_bquant_cr.hpp | 58 ++-- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 17 +- .../gemm_aquant_pipeline_ag_bg_cr_base.hpp | 6 +- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 19 +- .../gemm_aquant_pipeline_ag_bg_cr_policy.hpp | 8 +- .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 19 +- .../gemm_bquant_pipeline_ag_bg_cr_base.hpp | 16 +- .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 26 +- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 21 +- .../pipeline/gemm_group_quant_utils.hpp | 112 ++++++-- .../pipeline/gemm_quant_pipeline_problem.hpp | 23 +- ...p_bquant_pipeline_ag_bg_cr_base_policy.hpp | 5 +- .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 9 +- .../gemm_block_scale/test_gemm_quant_base.hpp | 22 +- .../test_gemm_quant_fixtures.hpp | 65 +++-- .../test_gemm_quant_typed.cpp | 47 +++- .../test_grouped_gemm_util_quant.hpp | 8 +- 24 files changed, 476 insertions(+), 363 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 94a2b279bc..213631721f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added WMMA (gfx12) support for FMHA. * Added pooling kernel in CK_TILE * Added top-k sigmoid kernel in CK_TILE +* Added the blockscale 2D support for CK_TILE GEMM. ### Changed diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp index edde59081c..b22596537f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp @@ -1,6 +1,11 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// This example demonstrates 2D block scale quantization (N×K) for BQuant +// using non-preshuffled configuration. +// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example +// This is currently done separately to avoid too verbose dispatching. + #include #include #include @@ -17,7 +22,7 @@ template float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s) @@ -57,11 +62,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str GemmTraits, ComputeDataType>; + // This example only supports BQuant (no AQuant) + // For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3 using BaseGemmPipeline = std::conditional_t< GemmConfig::PreshuffleB == true, ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseAQuantGemmPipelineAgBgCrMem>; // memory pipeline hardcoded - // for aquant + ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -229,7 +235,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str template int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { @@ -266,6 +272,41 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a return 0; } +// Forward declaration for dispatch function +template