From fe92102baf0b032e209968e3448dafbc040c6cfb Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Fri, 31 Oct 2025 20:13:43 +0000 Subject: [PATCH] add some documentation and 2d block scale example --- .../38_block_scale_gemm/CMakeLists.txt | 3 + .../gemm_quant_2d_block.cpp | 442 ++++++++++++++++++ .../38_block_scale_gemm/gemm_utils.hpp | 19 +- .../gemm_bquant_pipeline_ag_bg_cr_base.hpp | 4 +- .../pipeline/gemm_group_quant_utils.hpp | 56 ++- 5 files changed, 507 insertions(+), 17 deletions(-) create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_quant_2d_block.cpp diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 7358d4d749..1c419b1c15 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -8,6 +8,9 @@ list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp) target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + + add_executable(tile_example_gemm_quant_2d_block EXCLUDE_FROM_ALL gemm_quant_2d_block.cpp) + target_compile_options(tile_example_gemm_quant_2d_block PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile quant gemm tests for current target") endif() diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_2d_block.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_2d_block.cpp new file mode 100644 index 0000000000..5638ead778 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_2d_block.cpp @@ -0,0 +1,442 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +// This example demonstrates 2D block scale quantization (N×K) for BQuant +// using non-preshuffled configuration. +// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example +// This is currently done separately to avoid too verbose dispatching. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host.hpp" +#include "gemm_utils.hpp" + +template +float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s) +{ + static_assert(std::is_same_v); + using ComputeDataType = std::conditional_t; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + + using GemmTraits = ck_tile::TileGemmQuantTraits; + + using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; + + // This example only supports BQuant (no AQuant) + // For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3 + using BaseGemmPipeline = std::conditional_t< + GemmConfig::PreshuffleB == true, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, + ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>; + + const ck_tile::index_t K_split = + (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr bool transpose_c = false; + + // row-col and tensor quants use the regular pipeline, A/B quants use their own + using PipelineProblem = std::conditional_t< + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant, + ck_tile::GemmRowColTensorQuantPipelineProblem, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>>; + + using GemmPipeline = std::conditional_t< + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant, + ck_tile::GemmPipelineAgBgCrCompV3, + std::conditional_t< + QuantMode == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrMem, // memory pipeline hardcoded + // for aquant + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + ck_tile::memory_operation_enum::set, + 1, + false, + 1, + GemmConfig::TiledMMAPermuteN>>; + using Kernel = + ck_tile::QuantGemmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + + if(args.k_batch != 1) + { + throw std::runtime_error("split-k is not supported yet!"); + } + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << PipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + float ave_time = 0; + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + ck_tile::RotatingMemWrapper + rotating_mem( + kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString( + hipMemsetAsync(args.c_ptr, + 0, + args.M * args.N * sizeof(typename TypeConfig::CDataType), + s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + + return ave_time; + }; + return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + +#include "run_gemm_quant_example.inc" + +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if((QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::RowColQuant) && + GemmConfig::PreshuffleB) + { + throw std::runtime_error( + "Preshuffling weight matrix is not supported for AQuant or RowColQuant"); + } + + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } + } + else + { + throw std::runtime_error("Unsupported data type for A."); + } + + return 0; +} + +// Forward declaration for dispatch function +template