From bad7262507641389549e41ccf70d1807a71d943e Mon Sep 17 00:00:00 2001 From: Vijay Krish Date: Thu, 28 Aug 2025 23:43:02 -0700 Subject: [PATCH] ck_tile kernel for gemm with groupwise quantized B tensor. (#2663) * This change introduces new pipelines with the Intrawave scheduler and block gemm primitives that load the scale tensor to registers to perform dequantization post MFMA on the C tensor in registers. Scale tensor data, BQ, is spliced across threads in registers and not stored in LDS. Current support is for the following combinations, but it should be fairly straightforward to extend support to more formats. fp8, fp8 -> f32 bf8, bf8 -> f32 fp8, i4 -> f32 bf8, i4 -> f32 Group size can go as low as the K length of the underlying WarpGemm primitive. * Solve merge conflict * [CK TILE] Update CHANGELOG.md --------- Co-authored-by: Vijay Krishnamoorthy Co-authored-by: ThomasNing Co-authored-by: Cong Ma [ROCm/composable_kernel commit: 4208e2898818362735e1ae9980a4cc2fea607ab4] --- CHANGELOG.md | 1 + .../38_block_scale_gemm/CMakeLists.txt | 5 +- example/ck_tile/38_block_scale_gemm/README.md | 1 + .../38_block_scale_gemm/gemm_aquant_basic.cpp | 8 +- .../38_block_scale_gemm/gemm_bquant_basic.cpp | 229 ++++++ .../38_block_scale_gemm/gemm_utils.hpp | 8 +- .../run_gemm_aquant_example.inc | 1 + .../run_gemm_bquant_example.inc | 286 ++++++++ include/ck_tile/core/numeric/pk_fp4.hpp | 6 +- include/ck_tile/ops/gemm_group_quant.hpp | 9 +- .../block_universal_gemm_as_aquant_bs_cr.hpp | 16 +- .../block_universal_gemm_as_bs_bquant_cr.hpp | 439 +++++++++++ .../kernel/gemm_bquant_kernel.hpp | 679 ++++++++++++++++++ .../gemm_aquant_pipeline_ag_bg_cr_policy.hpp | 2 +- .../gemm_bquant_pipeline_ag_bg_cr_base.hpp | 53 ++ .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 93 +++ .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 475 ++++++++++++ .../pipeline/gemm_group_quant_utils.hpp | 54 +- ...em.hpp => gemm_quant_pipeline_problem.hpp} | 103 +++ ..._traits.hpp => tile_gemm_quant_traits.hpp} | 29 + 20 files
changed, 2471 insertions(+), 26 deletions(-) create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/run_gemm_bquant_example.inc create mode 100644 include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/kernel/gemm_bquant_kernel.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp rename include/ck_tile/ops/gemm_group_quant/pipeline/{gemm_aquant_pipeline_problem.hpp => gemm_quant_pipeline_problem.hpp} (53%) rename include/ck_tile/ops/gemm_group_quant/pipeline/{tile_gemm_aquant_traits.hpp => tile_gemm_quant_traits.hpp} (52%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 76fb46cdd9..8ae97b3d61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added int8 support for CK_TILE GEMM. * Added support for elementwise kernel. * Added benchmarking support for tile engine GEMM Multi D. +* Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands. 
### Optimized diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 914fdac0e4..12cf874c73 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -8,9 +8,8 @@ list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") add_executable(tile_example_gemm_aquant_basic EXCLUDE_FROM_ALL gemm_aquant_basic.cpp) target_compile_options(tile_example_gemm_aquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - - add_executable(tile_example_gemm_aquant_preshuffle EXCLUDE_FROM_ALL gemm_aquant_preshuffle.cpp) - target_compile_options(tile_example_gemm_aquant_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_executable(tile_example_gemm_bquant_basic EXCLUDE_FROM_ALL gemm_bquant_basic.cpp) + target_compile_options(tile_example_gemm_bquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile quant gemm tests for current target") endif() diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index fc905790f1..6d6aec28c8 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -10,6 +10,7 @@ mkdir build && cd build ../script/cmake-ck-dev.sh ../ # The aquant pipeline method on the gemm calculation make tile_example_gemm_aquant_basic -j +make tile_example_gemm_bquant_basic -j ``` This will result in an executable `build/bin/tile_example_gemm_aquant_basic` diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp index 744c844040..d5a38fe754 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp @@ -27,8 +27,6 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& 
args, const ck_tile::s constexpr bool kPadN = false; constexpr bool kPadK = false; - constexpr int kBlockPerCu = 1; - static_assert(std::is_same_v); constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile; @@ -139,7 +137,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s } float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; @@ -207,7 +205,7 @@ int run_gemm_example(int argc, char* argv[]) ck_tile::fp8_t, ck_tile::half_t, ck_tile::fp8_t>{}); - return run_gemm_example_prec_type, TypeConfig, 128>( + return run_gemm_example_prec_type, TypeConfig, 128>( a_layout, b_layout, argc, argv); } else if(data_type == "i4bf8") @@ -216,7 +214,7 @@ int run_gemm_example(int argc, char* argv[]) ck_tile::bf8_t, ck_tile::half_t, ck_tile::bf8_t>{}); - return run_gemm_example_prec_type, TypeConfig, 128>( + return run_gemm_example_prec_type, TypeConfig, 128>( a_layout, b_layout, argc, argv); } else diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp new file mode 100644 index 0000000000..991c4841e4 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "ck_tile/core/config.hpp"
+#include "ck_tile/host.hpp"
+#include "gemm_utils.hpp"
+
+// Configures and launches the B-quantized GEMM kernel for the tile sizes taken from GemmConfig.
+//
+// args: host-side GEMM arguments; this function reads args.K, args.M, args.N and args.k_batch
+//       and forwards the whole object to Kernel::MakeKernelArgs.
+// s:    stream configuration; s.log_level_ > 0 enables the launch log printed to std::cout.
+//
+// Returns whatever BaseGemmPipeline::TailHandler returns for the Run lambda below; the lambda
+// itself returns the value produced by ck_tile::launch_kernel.
+// Throws std::runtime_error when args.k_batch != 1 or when Kernel::IsSupportedArgument rejects
+// the kernel arguments.
+//
+// NOTE(review): the template parameter list and most template argument lists (and the system
+// header names above) are missing in this copy of the patch -- restore them from the upstream
+// commit before applying.
+template
+float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::stream_config& s)
+{
+    // Padding is disabled in every dimension for this example.
+    constexpr bool kPadM = false;
+    constexpr bool kPadN = false;
+    constexpr bool kPadK = false;
+
+
+    static_assert(std::is_same_v);
+
+    // Block-level tile extents.
+    constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
+    constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
+    constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
+
+    // Warp counts per dimension.
+    constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
+    constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
+    constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
+
+    // Warp-level tile extents.
+    constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
+    constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
+    constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
+
+    using CodegenGemmShape =
+        ck_tile::TileGemmShape,
+        ck_tile::sequence,
+        ck_tile::sequence>;
+
+    using TilePartitioner = ck_tile::GemmTile1DPartitioner;
+
+    using CodegenGemmTraits = ck_tile::TileGemmBQuantTraits;
+
+    using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase;
+
+    using BaseGemmPipeline = ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3;
+
+    // Round args.K up to the next multiple of K_Tile before deriving the loop count.
+    const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
+    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
+    const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
+    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
+    constexpr bool transposed_warp_gemm = false;
+
+
+    // Invoked by TailHandler with compile-time wrappers (read through `.value`) for the
+    // hot-loop flag and the tail number; builds the concrete kernel type and launches it.
+    const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
+        constexpr bool has_hot_loop_v = has_hot_loop_.value;
+        constexpr auto tail_number_v = tail_number_.value;
+
+        using CodegenPipelineProblem =
+            ck_tile::GemmBQuantPipelineProblem;
+        using CodegenGemmPipeline = ck_tile::BQuantGemmPipelineAgBgCrCompV3;
+        using GemmEpilogue = ck_tile::CShuffleEpilogue<
+            ck_tile::CShuffleEpilogueProblem,
+            AccDataType,
+            CDataType,
+            ck_tile::tuple<>,
+            CLayout,
+            ck_tile::element_wise::PassThrough,
+            TilePartitioner::MPerBlock,
+            TilePartitioner::NPerBlock,
+            M_Warp,
+            N_Warp,
+            M_Warp_Tile,
+            N_Warp_Tile,
+            K_Warp_Tile,
+            transposed_warp_gemm,
+            ck_tile::memory_operation_enum::set>>;
+        using Kernel =
+            ck_tile::BQuantGemmKernel;
+
+        auto kargs = Kernel::MakeKernelArgs(args);
+
+        const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
+        const dim3 blocks = Kernel::BlockSize();
+
+        // Only k_batch == 1 is accepted; anything else is rejected before launch.
+        if(args.k_batch != 1)
+        {
+            throw std::runtime_error("split-k is not supported yet!");
+        }
+
+        if(!Kernel::IsSupportedArgument(kargs))
+        {
+            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
+        }
+
+        if(s.log_level_ > 0)
+        {
+            std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
+                      << "shape: " << CodegenGemmShape::GetName() << '\n'
+                      << "problem: " << CodegenPipelineProblem::GetName() << '\n'
+                      << "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
+                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
+                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
+                      << std::endl;
+        }
+
+        float ave_time = ck_tile::launch_kernel(
+            s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs));
+
+        return ave_time;
+    };
+    return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
+}
+
+#include "run_gemm_bquant_example.inc"
+
+// Selects the layout combination for the example run.
+//
+// a_layout / b_layout: layout codes compared against "R" and "C"; only a_layout == "R" together
+//                      with b_layout == "C" is accepted.
+// argc / argv:         forwarded unchanged to run_gemm_example_with_layouts.
+//
+// Returns the result of run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Col{}, Row{}).
+// Throws std::runtime_error for any other layout pair, and when the compile-time type check
+// below fails (message: "Unsupported data type for B.").
+//
+// NOTE(review): the operands of the std::is_same_v checks are missing in this copy of the
+// patch; judging by the error message they presumably test B's data type -- confirm upstream.
+template
+int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
+{
+    using Row = ck_tile::tensor_layout::gemm::RowMajor;
+    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
+
+    if constexpr(std::is_same_v ||
+                 std::is_same_v ||
+                 std::is_same_v)
+    {
+        if(a_layout == "R" && b_layout == "C")
+        {
+            return run_gemm_example_with_layouts(
+                argc, argv, Row{}, Col{}, Col{}, Row{});
+        }
+        else
+        {
+            throw std::runtime_error("Unsupported memory layout for the input matrices!");
+        }
+    }
+    else
+    {
+        throw std::runtime_error("Unsupported data type for B.");
+    }
+
+    // Not reached by the branches above (each returns or throws); kept as in the original.
+    return 0;
+}
+
+template