From baf244000eb3d9215cbc1245ee1a63fcd8681352 Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Wed, 23 Jul 2025 01:10:16 -0600 Subject: [PATCH] ck_tile kernel for gemm with groupwise quantized A tensor (#2473) * ck_tile kernel for gemm with groupwise quantized A or B tensor. This change introduces new pipelines with Intrawave scheduler and block gemm primitives that loads the scale tensor to registers to perform dequantization post MFMA on C tensor in registers. Scale tensor data, AQ/BQ is spliced across threads in registers and not stored in LDS. Current support is for the following combinations, but it should be fairly straightforward to extend support to more formats. 1. fp8, fp8 -> f32 2. bf8, bf8 -> f32 3. i4, fp8 -> f32 4. i4, bf8 -> f32 Group size can go down to as low as K length of underlying WarpGemm primitive. For Gemm problems with quantized B tensor, this change also introduces preliminary support for flatmm pipeline which loads B tensor directly into registers. * [Block Scale Gemm] Only run gemm quant examples on __gfx94__ - Only run gemm quant examples on __gfx94__ for usage of `v_cvt_pk_fp8_f32` - Format the code * [Block Scale Gemm] Remove Bquant Gemm BlockScale This cleanup is in preparation for future development of bquant. By isolating Aquant-related code, we can streamline the codebase and make it easier to add and maintain bquant functionality in subsequent updates. * [Block Scale Gemm] Format code with clang-format-12 The latest clang-format (v19) in ROCm 7.0 generate different result than clang-format-12 which is used in CK CI. Format code with clang-format-12 for consistency. * [Block Scale Gemm] Split the k direction loop - Split the k direction loop in block_universal_gemm_as_quant_bs_cr.hpp to make the logic clearer. - Disable C transposition. 
* [Block Scale Gemm] Move block scale gemm example to 38_block_scale_gemm * [Block Scale Gemm] Update copyright * test * Add TailHandler * Move TileDistributionEncodingPatternAQ * Refactor * refactor * fix bug * fix bug * help solve the PR comment * Format the code * [Block Scale Gemm] Add unit tests * [Block Scale Gemm] Add support to 16x16x32 MFMA - Add support to 16x16x32 MFMA - Fix a bug when exchange data crossing lanes --------- Co-authored-by: Vijay Krishnamoorthy Co-authored-by: Cong MA Co-authored-by: ThomasNing [ROCm/composable_kernel commit: e62710e461d64f2740eaf46ba672d3173b7f17d1] --- example/ck_tile/03_gemm/gemm_utils.hpp | 14 +- .../38_block_scale_gemm/CMakeLists.txt | 13 + example/ck_tile/38_block_scale_gemm/README.md | 35 + .../38_block_scale_gemm/gemm_aquant_basic.cpp | 226 ++++++ .../38_block_scale_gemm/gemm_utils.hpp | 675 +++++++++++++++++ .../run_gemm_aquant_example.inc | 259 +++++++ example/ck_tile/CMakeLists.txt | 1 + include/ck_tile/core/numeric/pk_int4.hpp | 18 + include/ck_tile/host/fill.hpp | 55 ++ .../ck_tile/host/reference/reference_gemm.hpp | 104 +++ .../unary_element_wise_operation.hpp | 90 +++ .../ops/epilogue/cshuffle_epilogue.hpp | 4 +- .../ops/flatmm/pipeline/tile_flatmm_shape.hpp | 3 + .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 6 +- .../gemm/warp/warp_gemm_attribute_mfma.hpp | 1 + .../ck_tile/ops/gemm/warp/warp_gemm_impl.hpp | 7 +- include/ck_tile/ops/gemm_group_quant.hpp | 12 + .../block_universal_gemm_as_aquant_bs_cr.hpp | 489 +++++++++++++ .../kernel/gemm_aquant_kernel.hpp | 679 +++++++++++++++++ .../gemm_aquant_pipeline_ag_bg_cr_base.hpp | 53 ++ .../gemm_aquant_pipeline_ag_bg_cr_policy.hpp | 93 +++ .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 476 ++++++++++++ .../pipeline/gemm_aquant_pipeline_problem.hpp | 121 ++++ .../pipeline/gemm_group_quant_utils.hpp | 95 +++ .../pipeline/tile_gemm_aquant_traits.hpp | 34 + test/ck_tile/CMakeLists.txt | 1 + test/ck_tile/gemm_block_scale/CMakeLists.txt | 19 + .../test_gemm_aquant_basic_bf8.cpp 
| 6 + .../test_gemm_aquant_basic_fp8.cpp | 6 + .../test_gemm_aquant_basic_i4bf8.cpp | 6 + .../test_gemm_aquant_basic_i4f32bf8.cpp | 6 + .../test_gemm_aquant_basic_i4f32fp8.cpp | 6 + .../test_gemm_aquant_basic_i4fp8.cpp | 6 + .../test_gemm_aquant_utils.hpp | 681 ++++++++++++++++++ .../test_run_gemm_aquant_example.inc | 577 +++++++++++++++ 35 files changed, 4864 insertions(+), 13 deletions(-) create mode 100644 example/ck_tile/38_block_scale_gemm/CMakeLists.txt create mode 100644 example/ck_tile/38_block_scale_gemm/README.md create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_utils.hpp create mode 100644 example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc create mode 100644 include/ck_tile/ops/gemm_group_quant.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp create mode 100644 include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp create mode 100644 test/ck_tile/gemm_block_scale/CMakeLists.txt create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_bf8.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_fp8.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4bf8.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32bf8.cpp 
create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32fp8.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4fp8.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp create mode 100644 test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 7a9b5afaa2..24f64994cf 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -114,16 +114,16 @@ template struct GemmConfigComputeV3 : public GemmConfigBase { // Compute V3 only support Intrawave scheduler - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt new file mode 100644 index 0000000000..bdcb6f50bd --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -0,0 +1,13 @@ +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +list(APPEND 
EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) + +if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") + add_executable(tile_example_gemm_aquant_basic EXCLUDE_FROM_ALL gemm_aquant_basic.cpp) + target_compile_options(tile_example_gemm_aquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +else() + message(DEBUG "Skipping ck_tile quant gemm tests for current target") +endif() diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md new file mode 100644 index 0000000000..742a88dee7 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -0,0 +1,35 @@ +# GEMM Matrix Multiplication + +This folder contains example for Block Scale GEMM using ck_tile tile-programming implementation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# The aquant pipeline method on the gemm calculation +make tile_example_gemm_aquant_basic -j +``` +This will result in an executable `build/bin/tile_example_gemm_aquant_basic` + +## example +``` +args: + -b batch size (default:1) + -m m dimension (default:1024) + -n n dimension (default:2048) + -k k dimension (default:64) + -a_layout Tensor A data layout (default: R) + -b_layout Tensor B data layout (default: R) + -c_layout Tensor C data layout (default: R) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) + -e Absolute error tolerance (default:1e-5) + -prec data type. 
fp16/bf16/fp8/bf8/int8 (default:fp16) + -warmup number of iterations before benchmark the kernel (default:10) + -repeat number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) +``` diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp new file mode 100644 index 0000000000..a1ed3c4920 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host.hpp" +#include "gemm_utils.hpp" + +template +float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s) +{ + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr int kBlockPerCu = 1; + + static_assert(std::is_same_v); + + constexpr ck_tile::index_t M_Tile = 16; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 256; + + constexpr ck_tile::index_t M_Warp = 1; + constexpr ck_tile::index_t N_Warp = 4; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 16; + constexpr ck_tile::index_t N_Warp_Tile = 16; + constexpr ck_tile::index_t K_Warp_Tile = 32; + + using CodegenGemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + + using CodegenGemmTraits = + ck_tile::TileGemmAQuantTraits; + + using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; + + using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = 
BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + constexpr bool transposed_warp_gemm = false; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + + using CodegenPipelineProblem = + ck_tile::GemmAQuantPipelineProblem; + using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + CodegenPipelineProblem::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + transposed_warp_gemm, + ck_tile::memory_operation_enum::set>>; + using Kernel = + ck_tile::AQuantGemmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(args.k_batch != 1) + { + throw std::runtime_error("split-k is not supported yet!"); + } + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + +#include "run_gemm_aquant_example.inc" + +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } + } + else + { + throw std::runtime_error("Unsupported data type for A."); + } + + return 0; +} + +int run_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(data_type == "fp8") + { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "bf8") + { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == 
"i4fp8") + { + using TypeConfig = decltype( + GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4bf8") + { + using TypeConfig = decltype( + GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4f32fp8") + { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4f32bf8") + { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } +} + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp new file mode 100644 index 0000000000..35e80ddb89 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -0,0 +1,675 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm_group_quant.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_COMPUTE_V5 4 +#define CK_TILE_PIPELINE_PRESHUFFLE 5 + +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if defined(__gfx950__) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 
128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +} +template +constexpr ck_tile::index_t get_k_warp_tile_flatmm() +{ +#if defined(__gfx950__) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 32 : 64; +#endif +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool Preshuffle = false; +}; + +template +struct 
GemmConfigMemoryInterwave : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template +struct GemmConfigMemoryIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; +}; + +template +struct GemmConfigComputeV3 : public GemmConfigBase +{ + // Compute V3 only support Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 32; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + 
static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; +}; + +template +struct GemmConfigComputeV4 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV4_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV5 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64 / 
sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 2; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; +}; + +template +struct GemmConfigPreshufle_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + +template +struct GemmConfigPreshufle_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int 
kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + +template +struct GemmTypeConfig; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::int8_t; + using BDataType = ck_tile::int8_t; + using AccDataType = int32_t; + using CDataType = int32_t; +}; + +template +struct GemmQuantTypeConfig +{ + using ADataType = ADataType_; + using QDataType = QDataType_; + using BDataType = BDataType_; + using AccDataType = float; + using CDataType = CDataType_; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::half_t; + using QDataType = float; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using 
QDataType = float; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = float; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = float; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::half_t; + using QDataType = float; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = float; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = float; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using 
AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = float; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = float; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = ck_tile::fp8_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = ck_tile::bf8_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = float; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = float; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + 
+template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; + template + using UniversalGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("aq_layout", "R", "Aq tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Column by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_q", "0", "Tensor AQ stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. 
Validation on GPU") + .insert("prec", "i4fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("persistent", "0", "0:non-persistent, 1:persistent") + .insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// host API +float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc new file mode 100644 index 0000000000..9bdef9755b --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc @@ -0,0 +1,259 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once +#include +#include + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& aq_m_aqk_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t AQK, + ck_tile::index_t stride_A, + ck_tile::index_t stride_AQ, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) +{ + ck_tile::AQuantGemmHostArgs args; + args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer(); + args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.k_batch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.QK = AQK; + args.stride_A = stride_A; + args.stride_B = stride_B; + args.stride_C = stride_C; + args.stride_AQ = stride_AQ; + + float ave_time = gemm_calc_aquant( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(AQDataType) * M * AQK + + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B + << " StrideC =" << stride_C << " A_Layout =" << ALayout::name + << " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name + << " A_Type = " << DataTypeTraits::name + << " AQ_Type = " << DataTypeTraits::name + << " B_Type = " << DataTypeTraits::name + << " Acc_Type = " << DataTypeTraits::name + << " C_Type = " << DataTypeTraits::name << " : " 
<< ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + return ave_time; +} + +template +int run_gemm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const AQLayout aq_layout = AQLayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using ADataType = typename TypeConfig::ADataType; + using AQDataType = typename TypeConfig::QDataType; + using BDataType = typename TypeConfig::BDataType; + using AccDataType = typename TypeConfig::AccDataType; + using CDataType = typename TypeConfig::CDataType; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + if(K % QuantGroupSize != 0) + { + throw std::runtime_error("K must be aligned with QuantGroupSize"); + } + + ck_tile::index_t AQK = K / QuantGroupSize; + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor aq_m_aqk( + ck_tile::host_tensor_descriptor(M, AQK, 
stride_AQ, is_row_major(aq_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution fill_seed(0, 500); + + if(init_method == 0) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(aq_m_aqk); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + } + else if(init_method == 1) + { + std::cout << "Monotonic initialization is not supported." << std::endl; + return 0; + } + else if(init_method == 2) + { + ck_tile::FillConstant{static_cast(0x22)}(a_m_k); + ck_tile::FillConstant{static_cast(0.5f)}(aq_m_aqk); + ck_tile::FillConstant{static_cast(0x38)}(b_k_n); + } + else + { + a_m_k.SetZero(); + aq_m_aqk.SetZero(); + b_k_n.SetZero(); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + invoke_gemm(a_m_k_dev_buf, + aq_m_aqk_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + AQK, + stride_A, + stride_AQ, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + 
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_quant(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + if(!pass) + { + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl; + return false; + } + + return pass; +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 8989060842..db5cc71888 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -23,3 +23,4 @@ add_subdirectory(20_grouped_convolution) add_subdirectory(35_batched_transpose) add_subdirectory(36_copy) add_subdirectory(37_transpose) +add_subdirectory(38_block_scale_gemm) diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index 541093e337..ba8b87a9b8 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -116,6 +116,24 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x) return res; } +CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_int4_t& x) +{ + uint8_t x_u8 = ck_tile::bit_cast(x); + + float x_l = ((x_u8 & 0x0f) >> 0); + float x_h = ((x_u8 & 0xf0) >> 4); + + x_l = x_l > 7 ? x_l - 16 : x_l; + x_h = x_l > 7 ? 
x_l - 16 : x_l; + +#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE + fp32x2_t res = {x_h, x_l}; +#elif + fp32x2_t res = {x_l, x_h}; +#endif + return res; +} + CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x) { uint8_t x_u8 = ck_tile::bit_cast(x); diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 4a359e031f..9b31a7889d 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -92,6 +93,60 @@ struct FillUniformDistribution } }; +template <> +struct FillUniformDistribution +{ + float a_{-8.f}; // same type as primary template so that + // `FillUniformDistribution{-5.0f, 5.0f}` works for all types + float b_{7.f}; + std::optional seed_{11939}; + template + void operator()(ForwardIter first, ForwardIter last) const + { + if(a_ < -8.0f || b_ > 7.0f) + { + throw std::runtime_error( + "a_ or b_ of FillUniformDistribution is out of range."); + } + + int min_value = static_cast(a_); + int max_value = static_cast(b_); + constexpr auto int4_array = std::array{0x88, + 0x99, + 0xaa, + 0xbb, + 0xcc, + 0xdd, + 0xee, + 0xff, + 0x00, + 0x11, + 0x22, + 0x33, + 0x44, + 0x55, + 0x66, + 0x77}; + std::mt19937 gen(seed_.has_value() ? 
*seed_ : std::random_device{}()); + std::uniform_int_distribution dis(0, max_value - min_value + 1); + while(first != last) + { + int randomInt = dis(gen); + *first = int4_array[randomInt + (min_value + 8)]; + ++first; + } + } + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + namespace impl { // clang-format off diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index c88deaec01..70ca44170e 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -11,6 +11,110 @@ namespace ck_tile { +template +CK_TILE_HOST void reference_gemm_quant(const HostTensor& a_m_k, + const HostTensor& q, + const HostTensor& b_k_n, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) +{ + const std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); + + auto f_mn = [&](auto m, auto n) { + AccDataType v_acc = 0, v_block_acc = 0; + + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v); + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v); + static_assert(std::is_same_v); + static_assert(std::is_same_v || + std::is_same_v); + for(std::size_t k = 0; k < K; ++k) + { + AccDataType v_a; + AccDataType v_b; + if constexpr(std::is_same_v) + { + const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val); + if(k % 2 == 1) + v_a = fp32_val.hi; + else + v_a = fp32_val.lo; + } + else + { + v_a = ck_tile::type_convert(a_element_op(a_m_k(m, k))); + } + if constexpr(std::is_same_v) + { + const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); + 
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val); + if(k % 2 == 1) + v_b = fp32_val.hi; + else + v_b = fp32_val.lo; + } + else if constexpr(std::is_same_v) + { + v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n))); + } + else + { + v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); + } + v_block_acc += v_a * v_b; + + // Apply group dequant scale + if((k + 1) % QuantGroupSize == 0) + { + float scale = 0.f; + index_t outer_dim = (aquant) ? m : k / QuantGroupSize; + index_t inner_dim = (aquant) ? k / QuantGroupSize : n; + + if constexpr(std::is_same_v) + { + scale = q(outer_dim, inner_dim); + } + else if constexpr(std::is_same_v) + { + scale = fp8_to_float_raw(q(outer_dim, inner_dim)); + } + else if constexpr(std::is_same_v) + { + scale = bf8_to_float_raw(q(outer_dim, inner_dim)); + } + else + { + static_assert(false, "Unexpected Q datatype."); + } + v_block_acc *= scale; + v_acc += v_block_acc; + v_block_acc = 0; + } + } + + c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); + }; + + make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); + std::cout << std::endl; +} + template (a), src_hi; + uint32_t fp8x4_lo, fp8x4_hi; + float tmp_0, tmp_1; + + asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n" + "v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n" + "v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n" + "v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n" + "v_cvt_pk_fp8_f32 %[v_dst_lo], 
%[v_tmp_1], %[v_tmp_0]\n" + : [v_tmp_0] "+v"(tmp_0), + [v_tmp_1] "+v"(tmp_1), + [v_hi_src] "+v"(src_hi), + [v_dst_lo] "+v"(fp8x4_lo), + [v_dst_hi] "+v"(fp8x4_hi), + [v_src] "+v"(src) + :); + + return bit_cast(((static_cast(fp8x4_hi) << 32) | fp8x4_lo)); +} + +CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src) +{ + float res; + asm volatile("v_cvt_f32_fp8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src)); + return res; +} + +CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src) +{ + float res; + asm volatile("v_cvt_f32_bf8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src)); + return res; +} + +CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(int a) +{ + uint32_t src = static_cast(a), src_hi; + uint32_t bf8x4_lo, bf8x4_hi; + float tmp_0, tmp_1; + + asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n" + "v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n" + "v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n" + "v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n" + "v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n" + : [v_tmp_0] "+v"(tmp_0), + [v_tmp_1] "+v"(tmp_1), + [v_hi_src] "+v"(src_hi), + [v_dst_lo] "+v"(bf8x4_lo), + [v_dst_hi] "+v"(bf8x4_hi), + [v_src] "+v"(src) + :); + + return bit_cast(((static_cast(bf8x4_hi) << 32) | bf8x4_lo)); +} + struct PassThroughPack8 { template @@ -126,6 +206,16 @@ struct PassThroughPack8 y.lo = i4_to_bhalf4(bit_cast(x)); y.hi = i4_to_bhalf4(bit_cast(x) >> 16); } + + CK_TILE_HOST_DEVICE constexpr void 
operator()(fp8x8_t& y, const pk_int4x4_t& x) const + { + y = amd_assembly_i4_to_fp8x8(bit_cast(x)); + } + + CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const + { + y = amd_assembly_i4_to_bf8x8(bit_cast(x)); + } constexpr const static bool is_pack8_invocable = true; }; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index bf58544259..7ae63e17a7 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -69,6 +69,8 @@ struct CShuffleEpilogue using ODataType = remove_cvref_t; using DsDataType = remove_cvref_t; using DsLayout = remove_cvref_t; + using ATypeToUse = + std::conditional_t, BDataType, ADataType>; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, ADataType, BDataType>; @@ -201,7 +203,7 @@ struct CShuffleEpilogue static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle); static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle); - using WG = WarpGemmMfmaDispatcher [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); return concat('_', "pipeline_AgBgCrCompV3", - concat('x', MPerBlock, NPerBlock, KPerBlock, BlockSize), - concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, + concat('x', WaveNumM, WaveNumN), concat('x', kPadM, kPadN, kPadK)); // clang-format on } diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 27a81ff090..97fab489ab 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ 
b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -37,6 +37,7 @@ struct WarpGemmAtrributeMfma static constexpr index_t kN = Impl::kN; static constexpr index_t kK = Impl::kK; static constexpr index_t kKPerThread = Impl::kABKPerLane; + static constexpr index_t kCMLane = Impl::kCMLane; CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp index f9d50ed35e..38fd0d408b 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp @@ -11,9 +11,10 @@ struct WarpGemmImpl { using WarpGemmAttribute = remove_cvref_t; - static constexpr index_t kM = WarpGemmAttribute::kM; - static constexpr index_t kN = WarpGemmAttribute::kN; - static constexpr index_t kK = WarpGemmAttribute::kK; + static constexpr index_t kM = WarpGemmAttribute::kM; + static constexpr index_t kN = WarpGemmAttribute::kN; + static constexpr index_t kK = WarpGemmAttribute::kK; + static constexpr index_t kCMLane = WarpGemmAttribute::kCMLane; /// @brief The number of elements in K dimension processed by single thread in wavefront. /// /// @note Note that WarpGemm may run MFMA instruction multiple times (on different K). diff --git a/include/ck_tile/ops/gemm_group_quant.hpp b/include/ck_tile/ops/gemm_group_quant.hpp new file mode 100644 index 0000000000..0041c658b4 --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant.hpp @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" +#include "ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp" diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp new file mode 100644 index 0000000000..c1ff6a356e --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -0,0 +1,489 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/elementwise.hpp" + +namespace ck_tile { + +template +struct BlockGemmQuantBase +{ + using AQDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + + static constexpr index_t UnaryOpSize = UnaryOpSize_; + template + CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale) + { + float scale_reg_f = 0.f; + if constexpr(std::is_same_v) + { + scale_reg_f = + ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast(scale)); + } + else if constexpr(std::is_same_v) + { + scale_reg_f = + ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast(scale)); + } + else if constexpr(std::is_same_v) + { + scale_reg_f = ck_tile::bit_cast(scale); + } + else + { + static_assert(false, "AQDataType must be float, fp8_t or bf8_t."); + } + return scale_reg_f; + } + + template + CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, + const WarpWindow& warp_window) + { + const element_wise::PassThroughPack8 elementwise_op{}; + + static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); + constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; + const auto in_dstr_tensors = load_tile(warp_window); + + using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), + in_dstr_tensors.get_thread_buffer().template get_as()[i]); + }); + } +}; + +// A is block window on shared memory +// AQ (scale tensor) is block distributed tensor. +// Consecutive kQuantGroupSize elements of A are quantized with a separate scale. 
+// B is block window on shared memory +// C is block distributed tensor +template +struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase +{ + private: + template + struct GemmTraits_ + { + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr auto Scheduler = Problem::Scheduler; + + // Threadblock GEMM tile size + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr index_t AQPerBlock = KPerBlock / kQuantGroupSize; + + static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + // number of warps along M and N for threadblock's GEMM problem size + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + using I0 = number<0>; + using I1 = number<1>; + + static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}), + "Error! WarpGemm's MWarp is not consisten with BlockGemmShape!"); + static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}), + "Error! WarpGemm's NWarp is not consisten with BlockGemmShape!"); + static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}), + "Error! WarpGemm's M is not consisten with BlockGemmShape!"); + static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}), + "Error! 
WarpGemm's N is not consisten with BlockGemmShape!"); + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + static constexpr index_t QScalesPerBlockRow = + (KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize; + static constexpr index_t QScalesPerWarpGemmRow = + (WarpGemm::kK + kQuantGroupSize - 1) / kQuantGroupSize; + + static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; + + static_assert(kQuantGroupSize % WarpGemm::kK == 0, + "Error! WarpGemm::kK should be a multiple of kQuantGroupSize"); + static_assert(QScalesPerWarpGemmRow == 1, + "Error! kQuantGroupSize shouldn't be smaller than WarpGemm::kK"); + static_assert(KIterPerWarp % QScalesPerBlockRow == 0, + "Error! KItersPerWarp should be a multiple of QscalesPerBlockRow"); + + static_assert(KPerBlock / kQuantGroupSize > 0, + "Error! Each row of blockgemm should have a separate scale"); + + static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock, + "Error! Warps should cover all Block tile!"); + static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock, + "Error! Warps should cover all Block tile!"); + + // Currently tested combinations (A, AQ, B) + // 1. fp8, fp32, fp8 -> f32 + // 2. bf8, fp32, bf8 -> f32 + // 3. i4, (fp8/fp32) fp8 -> f32 + // 4. 
i4, (fp8/fp32) bf8 -> f32 + static_assert( + (std::is_same_v || std::is_same_v || + std::is_same_v< + ADataType, + bf8_t>)&&(std::is_same_v || + std::is_same_v< + BDataType, + bf8_t>)&&(std::is_same_v || + std::is_same_v || + std::is_same_v< + AQDataType, + ck_tile::bf8_t>)&&(std::is_same_v || + std::is_same_v)&&std:: + is_same_v); + + static constexpr index_t InterWaveSchedulingMacClusters = 1; + + static constexpr index_t KPack = WarpGemm::kKPerThread; + static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; + }; + + public: + using Traits = GemmTraits_; + + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using Base = BlockGemmQuantBase; + + using WarpGemm = remove_cvref_t; + + static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; + static constexpr index_t MIterPerWarp = Traits::MIterPerWarp; + static constexpr index_t NIterPerWarp = Traits::NIterPerWarp; + + static constexpr index_t MWarp = Traits::MWarp; + static constexpr index_t NWarp = Traits::NWarp; + + static constexpr auto Scheduler = Traits::Scheduler; + static constexpr uint8_t kA_cvt_scale = std::is_same_v ? 16 : 1; + static constexpr uint8_t kB_cvt_scale = std::is_same_v ? 
16 : 1; + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + static_assert(std::is_same_v); + + static constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + using I0 = number<0>; + using I1 = number<1>; + + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() + { + constexpr index_t KPerThread = Traits::KPerThread; + constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; + + constexpr index_t KPerInnerLoop = + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); + + constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; + + using KIterSeq = std::conditional_t, + sequence>; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + 
constexpr index_t KPerThread = Traits::KPerThread; + constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; + constexpr index_t KPerInnerLoop = + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); + constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; + + using KIterSeq = std::conditional_t, + sequence>; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + + private: + template + struct BlockGemmImpl + { + }; + + template + struct BlockGemmImpl + { + static constexpr auto ALdsTileDistr = + decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_warp_tile_; + BLdsTile b_warp_tile_; + + template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + if constexpr(std::is_same_v) + { + static_assert(std::is_same_v || + std::is_same_v); + Base::load_interleaved_pk_type(a_warp_tile_, a_block_window); + } + else + { + load_tile(a_warp_tile_, a_block_window); + } + if constexpr(std::is_same_v) + { + static_assert(std::is_same_v || + std::is_same_v); + Base::load_interleaved_pk_type(b_warp_tile_, b_block_window); + } + else + { + load_tile(b_warp_tile_, b_block_window); + } + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + AQBlockTensor& aq_block_tensor, + [[maybe_unused]] ASmemBlockWindow& a_block_window, + 
[[maybe_unused]] BSmemBlockWindow& b_block_window) + { + static_assert(std::is_same_v, + "The CDataType as defined in traits should be the same as correspoinding " + "C block tensor data type!"); + + // hot loop: + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + CWarpTensor c_warp_tensor; + + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; + + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = + a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = + b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + if constexpr(kIterInQScale == 0) + { + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + } + else + { + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } + }); + + // Need to multiply aquant with accumulated C + // + // The accumulated C tile has the standard distribution. For example + // lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0], + // [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0], + // [26,0], [27,0]. + // + // These elements are in different rows, need to get the scale value + // for the corresponding row. + // Based on aquant's tile distribution, it can be inferred which + // lane holds the relevant scale. For example, the scales corresponding + // to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9, + // 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively. + // + // These scales can be obtained using __builtin_amdgcn_ds_bpermute. 
+ + // MIters per warp + constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM; + + // Reg block offset based on mIter + constexpr index_t reg_block_offset = + ((mIter / mIters_per_warp) * Traits::AQPerBlock); + + constexpr index_t lane_base_offset = + (mIter % mIters_per_warp) * WarpGemm::kM; + + // Scale tensor offset along K + constexpr index_t src_reg_offset = reg_block_offset + kQScale; + + constexpr uint32_t kTileRows = 4; + constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows; + + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + static_for<0, WarpGemm::kM, WarpGemm::kCMLane>{}([&](auto c_row) { + // Multiply by 4 because output is stored in tiles of 4 + // x CNLane + constexpr uint32_t row_base = + ((c_row / kTiledCMsPerWarp) * kTiledCMsPerWarp) + + ((c_row % kTiledCMsPerWarp) / WarpGemm::kCMLane); + + constexpr uint32_t reg_offset_for_row_data = c_row / WarpGemm::kCMLane; + + // Lane index to source scale from + uint32_t src_lane_idx = lane_base_offset + row_base + + (__lane_id() / WarpGemm::kN * kTileRows); + + // Directly index into thread buffer corresponding to + // desired row coefficient + auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset]; + uint32_t scale_reg_dword; + + if constexpr(std::is_same_v) + { + scale_reg_dword = ck_tile::bit_cast(scale_reg); + } + else + { + scale_reg_dword = static_cast(scale_reg); + } + + // Pull scale data across lanes + int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( + src_lane_idx * 4, __builtin_bit_cast(int, scale_reg_dword)); + + float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg); + + c_block_tensor + .get_thread_buffer()[tbuf_offset + reg_offset_for_row_data] += + (c_warp_tensor.get_thread_buffer()[reg_offset_for_row_data] * + scale_reg_f * kA_cvt_scale * kB_cvt_scale); + }); + }); + }); + }); + } + }; + + public: + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr auto 
c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + return c_block_tensor; + } + + template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window); + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + AQBlockTensor& aq_block_tensor, + const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + block_gemm_impl_(c_block_tensor, aq_block_tensor, a_block_window, b_block_window); + } + + private: + BlockGemmImpl block_gemm_impl_{}; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp new file mode 100644 index 0000000000..b1f89fe2e2 --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp @@ -0,0 +1,679 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +struct AQuantGemmProblem +{ + CK_TILE_HOST AQuantGemmProblem() = default; + CK_TILE_HOST AQuantGemmProblem(index_t M_, + index_t N_, + index_t K_, + index_t QK_, + index_t stride_A_, + index_t stride_B_, + index_t stride_C_, + index_t stride_AQ_) + : M(M_), + N(N_), + K(K_), + QK(QK_), + stride_A(stride_A_), + stride_B(stride_B_), + stride_C(stride_C_), + stride_AQ(stride_AQ_) + { + } + + index_t M; + index_t N; + index_t K; + index_t QK; + index_t stride_A; + index_t stride_B; + index_t stride_C; + index_t stride_AQ; +}; + +struct AQuantGemmHostArgs : public AQuantGemmProblem +{ + CK_TILE_HOST AQuantGemmHostArgs() = default; + CK_TILE_HOST AQuantGemmHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* c_ptr_, + const void* aq_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t QK_, + index_t stride_A_, + index_t stride_B_, + index_t stride_C_, + index_t stride_AQ_) + : AQuantGemmProblem(M_, N_, K_, QK_, stride_A_, stride_B_, stride_C_, stride_AQ_), + a_ptr(a_ptr_), + b_ptr(b_ptr_), + aq_ptr(aq_ptr_), + c_ptr(c_ptr_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_ptr; + const void* aq_ptr; + void* c_ptr; + index_t k_batch; +}; + +struct AQuantGemmKernelArgs +{ + const void* a_ptr; + const void* b_ptr; + const void* aq_ptr; + void* c_ptr; + index_t M; + index_t N; + index_t K; + index_t QK; + index_t stride_A; + index_t stride_B; + index_t stride_C; + index_t stride_AQ; + index_t k_batch; +}; + +template +struct AQuantGemmKernel +{ + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using AQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t KernelBlockSize = 
GemmPipeline::BlockSize; + + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto I3 = number<3>(); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm", gemm_prec_str, GemmPipeline::GetName()); + // clang-format on + } + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) + { + return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + CK_TILE_HOST static constexpr AQuantGemmKernelArgs + MakeKernelArgs(const AQuantGemmHostArgs& hostArgs) + { + return AQuantGemmKernelArgs{hostArgs.a_ptr, + hostArgs.b_ptr, + hostArgs.aq_ptr, + hostArgs.c_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.QK, + hostArgs.stride_A, + hostArgs.stride_B, + hostArgs.stride_C, + hostArgs.stride_AQ, + hostArgs.k_batch}; + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(const AQuantGemmKernelArgs& kargs, + const std::size_t k_id = blockIdx.z) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); + const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1); + + if constexpr(std::is_same_v) + { + a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + else if constexpr(std::is_same_v) + { + a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A); + } + + if constexpr(std::is_same_v) + { + b_k_split_offset = 
__builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B); + } + else if constexpr(std::is_same_v) + { + b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + + if(k_id < static_cast(kargs.k_batch - 1)) + { + splitted_k = __builtin_amdgcn_readfirstlane(KRead); + } + else + { + splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t splitted_k; + }; + + CK_TILE_HOST static bool IsSupportedArgument(const AQuantGemmKernelArgs& kargs) + { + if(kargs.k_batch != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + } + return false; + } + + static_assert(std::is_same_v); + if(kargs.QK % GemmPipeline::GetVectorSizeAQ() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + } + return false; + } + + if constexpr(std::is_same_v) + { + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + return false; + } + if(kargs.K % GemmPipeline::GetVectorSizeA() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + } + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } + return false; + } + if(kargs.M % GemmPipeline::GetVectorSizeA() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for A 
tensor!"); + } + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } + return false; + } + if(kargs.N % GemmPipeline::GetVectorSizeB() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); + } + return false; + } + } + else + { + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + return false; + } + if(kargs.K % GemmPipeline::GetVectorSizeB() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!"); + } + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } + return false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!"); + } + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } + return false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + 
{ + CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!"); + } + return false; + } + } + return true; + } + + template + CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, + const BDataType* b_ptr, + const AQDataType* aq_ptr, + CDataType* c_ptr, + const AQuantGemmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset) + { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); + const auto& a_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + a_ptr, + make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + a_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + }(); + + const auto& aq_tensor_view = [&]() { + static_assert(std::is_same_v); + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.M, kargs.QK), + make_tuple(kargs.stride_AQ, 1), + number{}, + number<1>{}); + }(); + + const auto& b_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + return make_naive_tensor_view( + b_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.N), + 
make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + } + else + { + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + } + }(); + + // TODO: enable vector write for C in ColMajor + const auto& c_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_C, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_C), + number<1>{}, + number<1>{}); + } + }(); + + return make_tuple(a_tensor_view, aq_tensor_view, b_tensor_view, c_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + const auto& aq_pad_view = [&]() { + const auto& aq_tensor_view = views.at(I1); + 
static_assert(std::is_same_v); + return pad_tensor_view( + aq_tensor_view, + make_tuple(number{}, + number{}), + // TODO: Add support for padding. + sequence{}); + }(); + + const auto& b_pad_view = [&]() { + const auto& b_tensor_view = views.at(I2); + if constexpr(std::is_same_v) + { + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // TODO vector write in for C in ColMajor + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I3); + if constexpr(std::is_same_v) + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + return make_tuple(a_pad_view, aq_pad_view, b_pad_view, c_pad_view); + } + + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& a_pad_view = views.at(I0); + const auto& aq_pad_view = views.at(I1); + const auto& b_pad_view = views.at(I2); + const auto& c_pad_view = views.at(I3); + + const auto& a_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, i_m}); + } + }(); + + const auto& aq_block_window = [&]() { + static_assert(std::is_same_v); + return make_tile_window( + aq_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + }(); + + const auto& b_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + } + else + { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {0, i_n}); + } + }(); + + auto c_block_window = make_tile_window( 
+ c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return make_tuple(a_block_window, aq_block_window, b_block_window, c_block_window); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param aq_ptr input AQ pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + * @tparam DstInMemOp Destination memory operation (default: set). + */ + template + CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, + const BDataType* b_ptr, + const AQDataType* aq_ptr, + CDataType* c_ptr, + void* smem_ptr_0, + const AQuantGemmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( + a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + + // Run GEMM cooperatively by whole workgroup. 
+ const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& aq_block_window = gemm_tile_windows.at(I1); + const auto& b_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + + CK_TILE_DEVICE void operator()(AQuantGemmKernelArgs kargs) const + { + const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const SplitKBatchOffset splitk_batch_offset(kargs); + // options + const ADataType* a_ptr = static_cast(kargs.a_ptr); + const BDataType* b_ptr = static_cast(kargs.b_ptr); + const AQDataType* aq_ptr = static_cast(kargs.aq_ptr); + CDataType* c_ptr = static_cast(kargs.c_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + assert(kargs.k_batch == 1); + RunGemm(a_ptr, b_ptr, aq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp new file mode 100644 index 0000000000..1356d7e222 --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" + +namespace ck_tile { + +template +struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase +{ + using Base = GemmPipelineAgBgCrImplBase; + using ADataType = typename Base::ADataType; + using ALayout = typename Base::ALayout; + using BDataType = typename Base::BDataType; + using BLayout = typename Base::BLayout; + using BlockGemmShape = typename Base::BlockGemmShape; + + using AQLayout = remove_cvref_t; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize; + static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize; + + static_assert(KPerBlock % QuantGroupSize == 0, + "KPerBlock must be a multiple of QuantGroupSize"); + + // Create DRAM tile window for AQ + template + CK_TILE_DEVICE constexpr auto + GetAQDramLoadWindow(const AQDramBlockWindowTmp& aq_dram_block_window_tmp) const + { + static_assert(std::is_same_v); + + using YPerTile = number; + using XPerTile = number; + + auto aq_copy_dram_window = + make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(YPerTile(), XPerTile()), + aq_dram_block_window_tmp.get_window_origin(), + Policy::template MakeAQDramTileDistribution()); + return aq_copy_dram_window; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp new file mode 100644 index 0000000000..83b61e23fc --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro 
Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "gemm_group_quant_utils.hpp" + +namespace ck_tile { + +struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy +{ + using Base = UniversalGemmPipelineAgBgCrPolicy; + using Base::I0; + using Base::I1; + using Base::I2; + + using Base::ATileAccessPattern; + using Base::BTileAccessPattern; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ() + { + using AQLayout = remove_cvref_t; + using AQDataType = remove_cvref_t; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize; + + static_assert(std::is_same_v); + return GetAQGlobalVectorLoadSize(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeAQDramTileDistribution() + { + using AQLayout = remove_cvref_t; + using BlockGemmShape = typename Problem::BlockGemmShape; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize; + constexpr index_t VecLoadSize = GetVectorSizeAQ(); + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + + static_assert(std::is_same_v); + using TileEncodingPattern = TileDistributionEncodingPatternAQ; + + return TileEncodingPattern::Make2DStaticTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0, + "KPerWarpGemm must be a multiple of kQuantGroupSize!"); + + using WarpGemm = 
WarpGemmMfmaDispatcher; + static_assert(std::is_same_v || + std::is_same_v); + static_assert(std::is_same_v); + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + return AQuantBlockUniversalGemmAsBsCr{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp new file mode 100644 index 0000000000..9fb26eb4e0 --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -0,0 +1,476 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BaseAQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 +{ + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) + { + if(has_hot_loop) + { + if(tail_number == ck_tile::TailNumber::Full) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Odd) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Even) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("Unsupported tail number for this operation !!!"); + } + } + 
else + { + if(tail_number == ck_tile::TailNumber::Full) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Odd) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Even) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("Unsupported tail number for this operation !!!"); + } + } + } +}; + +template +struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV3 +{ + using Base = BaseGemmPipelineAgBgCrCompV3; + using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t AQPackedSize = + ck_tile::numeric_traits>::PackedSize; + + using ALayout = remove_cvref_t; + using AQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockGemm = remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize; + static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize; + + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + 
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + static constexpr index_t GetVectorSizeAQ() + { + return Policy::template GetVectorSizeAQ(); + } + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + using Base::PrefetchStages; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + return concat('_', "aquant_pipeline_AgBgCrCompV3", + concat('x', MPerBlock, NPerBlock, KPerBlock), + BlockSize, + concat('x', WaveNumM, WaveNumN), + concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK), + concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST static std::string Print() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr 
index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + constexpr index_t AQ_Buffer_Load_Inst_Num = + MPerBlock * KPerBlockAQ / (BlockSize * GetVectorSizeAQ()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + auto str = std::stringstream{}; + + str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", " + << "AQ vector size: " << GetVectorSizeAQ() << "\n" + << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n" + << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num + << ", " + << "AQ buffer load inst: " << AQ_Buffer_Load_Inst_Num << "\n" + << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num + << "\n" + << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n" + << "C MFMA inst: " << C_MFMA_Inst_Num << "\n" + << "QuantGroupSize: " << QuantGroupSize << "\n" + << "KPack: " << BlockGemm::Traits::KPack << "\n" + << "PrefetchStages: " << PrefetchStages << "\n"; + return str.str(); + } + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + template 
+ CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "A/B/AQ Dram block window should have the same data type as appropriate " + "([A|B|AQ]DataType) defined in Problem definition!"); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_aq_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)"); + static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}], + "Aq block window has incorrect lengths for defined AqLayout!"); + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? 
(KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex; + + auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp); + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + using AQBlockTile = + decltype(make_static_distributed_tensor(AQBlockTileDistr{})); + + auto block_gemm = BlockGemm(); + + ABlockTile a_block_tile; + BBlockTile b_block_tile; + AQBlockTile aq_block_tile[2]; + int currIdx = 0; + + auto c_block_tile = block_gemm.MakeCBlockTile(); + + constexpr ADramTileWindowStep 
a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr AQDramTileWindowStep aq_dram_tile_window_step = + is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ); + + // DRAM prefetch (global read 0) + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch( + aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step); + + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffled2DStaticTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffled2DStaticTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + __builtin_amdgcn_sched_barrier(0); + + if constexpr(HasHotLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + 
transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], + aq_copy_dram_window, + aq_dram_tile_window_step); + + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + + currIdx = (currIdx + 1) % 2; + + block_sync_lds(); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) + { + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + } + else + { + Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], + aq_copy_dram_window, + aq_dram_tile_window_step); + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + currIdx = (currIdx + 1) % 2; + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + 
auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + } + return c_block_tile; + } + }; + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + aq_dram_block_window_tmp, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp new file mode 100644 index 0000000000..4cca30fd3b --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" + +#include + +namespace ck_tile { + +template +struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase +{ + using Base = GemmPipelineProblemBase; + + using Traits = typename Base::Traits; + + using typename Base::ADataType; + using typename Base::BDataType; + using typename Base::CDataType; + using typename Base::ComputeDataType; + using AQDataType = remove_cvref_t; + + using BlockGemmShape = typename Base::BlockGemmShape; + + using typename Base::ALayout; + using typename Base::BLayout; + using typename Base::CLayout; + + static constexpr bool TransposeC = false; + + using Base::kBlockSize; + + using Base::kPadK; + using Base::kPadM; + using Base::kPadN; + + using Base::DoubleSmemBuffer; + using Base::VectorLoadSize; + + using AQLayout = remove_cvref_t; + + static constexpr uint32_t kQuantGroupSize = QuantGroupSize_; + static constexpr auto Scheduler = Scheduler_; + static constexpr auto HasHotLoop = HasHotLoop_; + static constexpr auto TailNum = TailNum_; + + static_assert(BlockGemmShape::kK % kQuantGroupSize == 0); + static_assert(Scheduler == GemmPipelineScheduler::Intrawave); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm_aquant_problem", + concat('x', VectorLoadSize, kBlockSize), + concat('x', kPadM, kPadN, kPadK), + Scheduler, + "QuantGroupSize", + kQuantGroupSize); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentAQ() + { + static_assert(std::is_same_v); + return VectorLoadSize / sizeof(AQDataType); + } + + static constexpr index_t VectorSizeAQ = []() { + static_assert(std::is_same_v); + return kPadK ? 
1 : GetAlignmentAQ();
+    }();
+};
+
+template
+using GemmAQuantPipelineProblem = GemmAQuantPipelineProblemBase;
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp
new file mode 100644
index 0000000000..c018314ab7
--- /dev/null
+++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp
@@ -0,0 +1,95 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
+
+// Selects the widest global vector load size (in elements) usable for the AQ
+// (A-scale) tile, given the tile extents and the per-thread element count.
+template
+CK_TILE_HOST_DEVICE static constexpr auto GetAQGlobalVectorLoadSize()
+{
+    using I1 = number<1>;
+    constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{});
+
+    constexpr index_t BlockSize = Problem::kBlockSize;
+
+    // Data is replicated across warps along NWarps, so we divide BlockSize by NWarps
+    constexpr index_t elements_per_thread = (YPerTile * XPerTile) / (BlockSize / NWarps);
+    constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize;
+
+    // Define vector load candidates in descending order of priority
+    constexpr std::array candidates{
+        PackedSize * 32 / sizeof(DataType),
+        PackedSize * 16 / sizeof(DataType),
+        PackedSize * 8 / sizeof(DataType),
+        PackedSize * 4 / sizeof(DataType),
+        PackedSize * 2 / sizeof(DataType),
+    };
+
+    // Return the first (widest) candidate that evenly tiles both the X extent and the
+    // per-thread element count. The smallest candidate is excluded by design so that
+    // the PackedSize fallback is used instead of a 2-byte-wide load.
+    for(const auto vec_size : candidates)
+    {
+        if(vec_size == candidates[4])
+        {
+            continue; // excluded by design; fall through to the PackedSize fallback
+        }
+        if(vec_size > 0 && XPerTile % vec_size == 0 && elements_per_thread % vec_size == 0)
+        {
+            return vec_size;
+        }
+    }
+    return PackedSize; // Absolute fallback
+}
+
+// AQ holds groupquant scale data for A. Data is loaded from DRAM and partitioned across
+// threads.
Post mfma scales are shuffled across threads in the warp and applied to +// accum registers. +template +struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern +{ + // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! + static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); + static constexpr index_t warp_size = get_warp_size(); + static constexpr index_t num_warps = BlockSize / get_warp_size(); + + static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{}); + static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{}); + static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{}); + + static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM); + + static_assert(num_warps == MWarps * NWarps * KWarps); + + // KWarps > 1 isn't supported + static_assert(KWarps == 1); + + // # of elements per thread + static constexpr index_t X = XPerTile; + + static constexpr index_t Y0 = 1; + static constexpr index_t Y1 = MIterPerWarp ? 
MIterPerWarp : 1; + static constexpr index_t Y2 = MWarps; + static constexpr index_t Y3 = WarpGemm::kM; + static_assert(Y3 >= WarpGemm::kM, "Scales for all rows must be available within the warp."); + static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile, + "Y0, Y1, Y2, Y3 must cover the blocktile along Y."); + + CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 1>>, + tuple, sequence<0, 3>>, + sequence<1, 2>, + sequence<1, 0>>{}); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp new file mode 100644 index 0000000000..4972badb3f --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct TileGemmAQuantTraits +{ + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kPadK = kPadK_; + + static constexpr int _VectorSize = 16; + + using ALayout = ALayout_; + using BLayout = BLayout_; + using CLayout = CLayout_; + using AQLayout = AQLayout_; + + static constexpr bool UseStructuredSparsity = false; + static constexpr index_t NumWaveGroups = 1; +}; + +} // namespace ck_tile diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 3e5a3034cd..8f3fbd52c5 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -17,3 +17,4 @@ add_subdirectory(topk_softmax) add_subdirectory(add_rmsnorm2d_rdquant) # add_subdirectory(layernorm2d) # add_subdirectory(rmsnorm2d) +add_subdirectory(gemm_block_scale) diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt new file mode 100644 index 0000000000..847ab88644 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -0,0 +1,19 @@ +set(TEST_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND TEST_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) + +if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") + set(TEST_GEMM_NAME test_tile_gemm_aquant_basic) + set(QUANT_TYPES fp8 bf8 i4fp8 i4bf8 i4f32fp8 i4f32bf8) + + foreach(QUANT_TYPE ${QUANT_TYPES}) + add_gtest_executable(${TEST_GEMM_NAME}_${QUANT_TYPE} test_gemm_aquant_basic_${QUANT_TYPE}.cpp) + target_compile_options(${TEST_GEMM_NAME}_${QUANT_TYPE} PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + endforeach() + +else() + message(DEBUG "Skipping ck_tile quant gemm tests for current target") +endif() diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_bf8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_bf8.cpp new file mode 
100644 index 0000000000..9c4277d879 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_bf8.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_run_gemm_aquant_example.inc" + +int main() { return run_gemm_combinations("bf8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_fp8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_fp8.cpp new file mode 100644 index 0000000000..b0cf55be6f --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_fp8.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_run_gemm_aquant_example.inc" + +int main() { return run_gemm_combinations("fp8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4bf8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4bf8.cpp new file mode 100644 index 0000000000..fd80bf2b06 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4bf8.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_run_gemm_aquant_example.inc" + +int main() { return run_gemm_combinations("i4bf8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32bf8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32bf8.cpp new file mode 100644 index 0000000000..fe8c9c5000 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32bf8.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "test_run_gemm_aquant_example.inc" + +int main() { return run_gemm_combinations("i4f32bf8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32fp8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32fp8.cpp new file mode 100644 index 0000000000..a319d9c2ad --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32fp8.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_run_gemm_aquant_example.inc" + +int main() { return run_gemm_combinations("i4f32fp8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4fp8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4fp8.cpp new file mode 100644 index 0000000000..ceb8760435 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4fp8.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_run_gemm_aquant_example.inc" + +int main() { return run_gemm_combinations("i4fp8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp new file mode 100644 index 0000000000..40f6712ef9 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp @@ -0,0 +1,681 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm_group_quant.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_COMPUTE_V5 4 +#define CK_TILE_PIPELINE_PRESHUFFLE 5 + +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if defined(__gfx950__) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +} +template +constexpr ck_tile::index_t get_k_warp_tile_flatmm() +{ +#if defined(__gfx950__) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 
32 : 64; +#endif +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +class ArgumentsNotSupportedException : public std::logic_error +{ + public: + explicit ArgumentsNotSupportedException(const std::string& message) : logic_error(message) {} +}; + +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool Preshuffle = false; +}; + +template +struct GemmConfigMemoryInterwave : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / 
sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template +struct GemmConfigMemoryIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; +}; + +template +struct GemmConfigComputeV3 : public GemmConfigBase +{ + // Compute V3 only support Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 32; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + 
static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; +}; + +template +struct GemmConfigComputeV4 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV4_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV5 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64 / 
sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 2; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaveGroups = 2; +}; + +template +struct GemmConfigPreshufle_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + +template +struct GemmConfigPreshufle_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int 
kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + +template +struct GemmTypeConfig; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::int8_t; + using BDataType = ck_tile::int8_t; + using AccDataType = int32_t; + using CDataType = int32_t; +}; + +template +struct GemmQuantTypeConfig +{ + using ADataType = ADataType_; + using QDataType = QDataType_; + using BDataType = BDataType_; + using AccDataType = float; + using CDataType = CDataType_; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::half_t; + using QDataType = float; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using 
QDataType = float; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = float; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = float; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::half_t; + using QDataType = float; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = float; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = float; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using 
AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = float; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = float; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = ck_tile::fp8_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = ck_tile::bf8_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = float; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = float; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + 
+template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; + template + using UniversalGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("aq_layout", "R", "Aq tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Column by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_q", "0", "Tensor AQ stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. 
Validation on GPU") + .insert("prec", "i4fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("persistent", "0", "0:non-persistent, 1:persistent") + .insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// host API +float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc new file mode 100644 index 0000000000..f410b58053 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc @@ -0,0 +1,577 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host.hpp" +#include "test_gemm_aquant_utils.hpp" + +template +float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s) +{ + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr int kBlockPerCu = 1; + + static_assert(std::is_same_v); + + constexpr ck_tile::index_t M_Tile = 16; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 256; + + constexpr ck_tile::index_t M_Warp = 1; + constexpr ck_tile::index_t N_Warp = 4; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 16; + constexpr ck_tile::index_t N_Warp_Tile = 16; + constexpr ck_tile::index_t K_Warp_Tile = 32; + + using CodegenGemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + + using CodegenGemmTraits = + ck_tile::TileGemmAQuantTraits; + + using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; + + using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + constexpr bool transposed_warp_gemm = false; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + + using CodegenPipelineProblem = + ck_tile::GemmAQuantPipelineProblem; + using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + 
ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + CodegenPipelineProblem::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + transposed_warp_gemm, + ck_tile::memory_operation_enum::set>>; + using Kernel = + ck_tile::AQuantGemmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(args.k_batch != 1) + { + throw std::runtime_error("split-k is not supported yet!"); + } + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& aq_m_aqk_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t AQK, + ck_tile::index_t stride_A, + ck_tile::index_t stride_AQ, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + 
ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) +{ + ck_tile::AQuantGemmHostArgs args; + args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer(); + args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.k_batch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.QK = AQK; + args.stride_A = stride_A; + args.stride_B = stride_B; + args.stride_C = stride_C; + args.stride_AQ = stride_AQ; + + float ave_time = gemm_calc_aquant( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(AQDataType) * M * AQK + + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B + << " StrideC =" << stride_C << " A_Layout =" << ALayout::name + << " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name + << " A_Type = " << DataTypeTraits::name + << " AQ_Type = " << DataTypeTraits::name + << " B_Type = " << DataTypeTraits::name + << " Acc_Type = " << DataTypeTraits::name + << " C_Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + return ave_time; +} + +template +bool run_gemm_test_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const AQLayout aq_layout = AQLayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return false; + + using ADataType = typename TypeConfig::ADataType; + using AQDataType = typename TypeConfig::QDataType; + using BDataType = typename 
TypeConfig::BDataType; + using AccDataType = typename TypeConfig::AccDataType; + using CDataType = typename TypeConfig::CDataType; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + if(K % QuantGroupSize != 0) + { + throw std::runtime_error("K must be aligned with QuantGroupSize"); + } + + ck_tile::index_t AQK = K / QuantGroupSize; + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor aq_m_aqk( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution fill_seed(0, 500); + + if(init_method == 0) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, 
fill_seed(gen)}(a_m_k); + } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(aq_m_aqk); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + } + else if(init_method == 1) + { + std::cout << "Monotonic initialization is not supported." << std::endl; + return true; + } + else if(init_method == 2) + { + ck_tile::FillConstant{static_cast(0x22)}(a_m_k); + ck_tile::FillConstant{static_cast(0.5f)}(aq_m_aqk); + ck_tile::FillConstant{static_cast(0x38)}(b_k_n); + } + else + { + a_m_k.SetZero(); + aq_m_aqk.SetZero(); + b_k_n.SetZero(); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + invoke_gemm(a_m_k_dev_buf, + aq_m_aqk_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + AQK, + stride_A, + stride_AQ, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_quant(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + if(!pass) + { + 
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl; + return false; + } + + return pass; +} + +template +bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_test_with_layouts( + argc, argv, Row{}, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } + } + else + { + throw std::runtime_error("Unsupported data type for A."); + } + + return true; +} + +bool run_gemm_test(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return false; + + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(data_type == "fp8") + { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "bf8") + { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4fp8") + { + using TypeConfig = decltype( + GemmQuantTypeConfig{}); + return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4bf8") + { + using TypeConfig = decltype( + GemmQuantTypeConfig{}); + return run_gemm_test_prec_type(a_layout, b_layout, argc, 
argv); + } + else if(data_type == "i4f32fp8") + { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4f32bf8") + { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } +} + +int run_gemm_combinations(std::string const& data_type) +{ + // Define possible values for each parameter + std::vector> mnk_values = {{ + "1", + "2048", + "5120", + }, + { + "2", + "2048", + "5120", + }, + { + "16", + "2048", + "5120", + }, + { + "17", + "2048", + "5120", + }, + { + "2047", + "5120", + "1024", + }, + { + "2048", + "5120", + "1024", + }}; + std::vector prec_values = {data_type}; + + // We'll store all our arguments as strings first + std::vector arg_strings = {"test_tile_gemm_aquant_basic", + "", // m placeholder + "", // n placeholder + "", // k placeholder + "", // prec placeholder + "-init=0", + "-v=1", + "-warmup=0", + "-repeat=1"}; + + // Create an array of const char pointers for argv + constexpr size_t ARG_COUNT = 9; + constexpr size_t ARG_MAX_LEN = 64; + char args[ARG_COUNT][ARG_MAX_LEN]; + char* argv[ARG_COUNT]; + + // Run all combinations + bool is_success = true; + for(const auto& mnk : mnk_values) + { + arg_strings[1] = "-m=" + mnk[0]; + arg_strings[2] = "-n=" + mnk[1]; + arg_strings[3] = "-k=" + mnk[2]; + + for(const auto& prec : prec_values) + { + arg_strings[4] = "-prec=" + prec; + + // Set up the argv array with pointers to the string data + for(size_t i = 0; i < ARG_COUNT; i++) + { + strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN); + argv[i] = args[i]; + } + + std::cout << "Arguments received: "; + for(size_t i = 1; i < ARG_COUNT; ++i) + { + std::cout << argv[i] << " "; + } + std::cout << std::endl; + + // Call the function with the current configuration + try + { + is_success = 
run_gemm_test(ARG_COUNT, argv) && is_success; + } + catch(const ArgumentsNotSupportedException& e) + { + std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n'; + // ArgumentsNotSupportedException is not an error. Do not change is_success + } + catch(const std::runtime_error& e) + { + std::cerr << "Caught runtime error: " << e.what() << '\n'; + is_success = false; + } + } + } + return is_success ? EXIT_SUCCESS : EXIT_FAILURE; +}