From 9ec8eac0798798823b8fa21cdcd04c2695881f0c Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Tue, 6 Jan 2026 21:12:56 +0000 Subject: [PATCH] Merge commit 'aaa35f0bbfa45dadc4380ddd6e0224668ddb97b4' into develop --- CHANGELOG.md | 1 + .../ck_tile/17_grouped_gemm/CMakeLists.txt | 3 +- .../17_grouped_gemm/abquant_grouped_gemm.cpp | 278 ++++++++ .../17_grouped_gemm/abquant_grouped_gemm.hpp | 171 +++++ .../run_grouped_gemm_abquant_example.inc | 604 ++++++++++++++++ .../gemm_abquant_quantgrouped.cpp | 60 ++ ...mm_bquant_quantgrouped_preshufflequant.cpp | 245 ++++++- .../run_gemm_quant_example.inc | 51 +- include/ck_tile/ops/gemm_quant.hpp | 3 + ...versal_gemm_ar_aquant_flatbr_bquant_cr.hpp | 282 ++++++++ .../block_universal_gemm_as_bs_bquant_cr.hpp | 1 + .../gemm_quant/kernel/gemm_quant_kernel.hpp | 645 ++---------------- .../kernel/grouped_gemm_quant_kernel.hpp | 17 +- .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 5 +- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 15 +- .../pipeline/gemm_group_quant_utils.hpp | 164 ++++- ..._abquant_pipeline_ag_bg_cr_base_policy.hpp | 120 ++++ .../gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 611 +++++++++++++++++ .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 20 +- test/ck_tile/CMakeLists.txt | 1 + test/ck_tile/gemm_block_scale/CMakeLists.txt | 6 + .../test_gemm_quant_abquant_preshuffle_2d.cpp | 44 ++ .../test_gemm_quant_fixtures.hpp | 12 +- .../grouped_gemm_abquant/CMakeLists.txt | 16 + .../test_grouped_gemm_abquant_1x128x128.cpp | 47 ++ .../test_grouped_gemm_abquant_1x1x128.cpp | 47 ++ .../test_grouped_gemm_abquant_ut_cases.inc | 87 +++ .../test_grouped_gemm_abquant_util.hpp | 530 ++++++++++++++ 28 files changed, 3387 insertions(+), 699 deletions(-) create mode 100644 example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp create mode 100644 example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp create mode 100644 example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc create mode 100644 include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp create mode 100755 include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp create mode 100644 include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp create mode 100644 test/ck_tile/grouped_gemm_abquant/CMakeLists.txt create mode 100644 test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x128x128.cpp create mode 100644 test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x1x128.cpp create mode 100644 test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_ut_cases.inc create mode 100644 test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index b149a74df3..3280ad07dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## (Unreleased) Composable Kernel 1.3.0 ### Added +* Added preshuffleB support for abquant mode in blockscale GEMM. * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. * Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". * Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines. 
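The ABQuant path added by this patch dequantizes both GEMM operands with per-group block scales. As a rough reference for what the new example validates against (a minimal sketch, assuming fp32 stand-ins for the fp8/bf8 operands and 1x1x128 quant groups on both A and B; the function and parameter names are illustrative, not the CK-Tile reference_gemm_abquant API):

    // C[m][n] = sum_k (A[m][k] * AQ[m][k/GA]) * (B[k][n] * BQ[k/GB][n])
    #include <cstddef>
    #include <vector>

    void abquant_ref_gemm(const std::vector<float>& A,  // M x K, row-major
                          const std::vector<float>& AQ, // M x (K/GA) scales, row-major
                          const std::vector<float>& B,  // K x N, column-major
                          const std::vector<float>& BQ, // (K/GB) x N scales, column-major
                          std::vector<float>& C,        // M x N, row-major
                          int M, int N, int K, int GA, int GB)
    {
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                {
                    // Dequantize each operand with the scale of its K-group.
                    const float a =
                        A[std::size_t(m) * K + k] * AQ[std::size_t(m) * (K / GA) + k / GA];
                    const float b =
                        B[std::size_t(n) * K + k] * BQ[std::size_t(n) * (K / GB) + k / GB];
                    acc += a * b;
                }
                C[std::size_t(m) * N + n] = acc;
            }
    }

The device pipeline below computes the same result more cheaply: it accumulates a full K-group in the MFMA accumulator and applies the two scales once per group (see the tail of BlockGemmWeightPreshuffleABQuantARegBRegCReg further down), rather than scaling every element.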
diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index 9b51af22fe..0f0a0d8ba7 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -14,7 +14,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95") quant_grouped_gemm_bf8_rowcol.cpp quant_grouped_gemm_bf8_tensor.cpp ) - + add_executable(tile_example_abquant_grouped_gemm abquant_grouped_gemm.cpp) add_executable(tile_example_grouped_gemm_preshuffle grouped_gemm_preshuffle.cpp) add_executable(tile_example_grouped_gemm_multi_d grouped_gemm_multi_d.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) @@ -25,4 +25,5 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95") target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_abquant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp new file mode 100644 index 0000000000..84da1e26da --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp @@ -0,0 +1,278 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm_quant.hpp" +#include "ck_tile/host.hpp" +#include "abquant_grouped_gemm.hpp" + +// Non-persistent grouped gemm for ABQuant +template +float grouped_gemm_abquant(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + GemmQuantConfig::template BaseGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + 
TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + +// Persistent grouped gemm tileloop for ABQuant +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); +} + +#include "run_grouped_gemm_abquant_example.inc" + +int main(int argc, char* argv[]) +{ + int result1 = run_abquant_grouped_gemm_example(argc, argv); + return result1; +} diff --git a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp new file mode 100644 index 0000000000..da8bd5514c --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp @@ -0,0 +1,171 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <tuple>
+#include <vector>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/kernel_launch.hpp"
+#include "ck_tile/ops/gemm.hpp"
+#include "ck_tile/utility/json_dump.hpp"
+
+template <typename DataType>
+struct GemmTypeConfig;
+
+template <>
+struct GemmTypeConfig<ck_tile::fp8_t>
+{
+    using ADataType = ck_tile::fp8_t;
+    using BDataType = ck_tile::fp8_t;
+    using AccDataType = float;
+    using CDataType = ck_tile::half_t;
+};
+template <>
+struct GemmTypeConfig<ck_tile::bf8_t>
+{
+    using ADataType = ck_tile::bf8_t;
+    using BDataType = ck_tile::bf8_t;
+    using AccDataType = float;
+    using CDataType = ck_tile::half_t;
+};
+
+template <bool Persistent_>
+struct GemmConfigBase
+{
+    static constexpr bool kPadM = false;
+    static constexpr bool kPadN = false;
+    static constexpr bool kPadK = false;
+
+    static constexpr bool PermuteA = false;
+    static constexpr bool PermuteB = false;
+
+    static constexpr bool TransposeC = false;
+    static constexpr bool UseStructuredSparsity = false;
+
+    static constexpr int kBlockPerCu = 1;
+    static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
+    static constexpr ck_tile::index_t TileParitionerM01 = 4;
+    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
+    static constexpr ck_tile::index_t NumWaveGroups = 1;
+    static constexpr bool DoubleSmemBuffer = false;
+    static constexpr bool PreshuffleB = false;
+    static constexpr bool Persistent = Persistent_;
+};
+
+template <typename PrecType, bool Persistent_>
+struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent_>
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 4;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 16;
+    static constexpr ck_tile::index_t N_Warp_Tile = 16;
+    static constexpr ck_tile::index_t K_Warp_Tile =
+        ck_tile::get_k_warp_tile();
+};
+
+template <ck_tile::QuantType QuantMode>
+struct GemmQuantConfig;
+
+// ABQuant specialization for GemmQuantConfig
+template <>
+struct GemmQuantConfig<ck_tile::QuantType::ABQuantGrouped>
+{
+    template <typename PrecType, bool Persistent_>
+    using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent_>;
+
+    template <typename Problem>
+    using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3<Problem>;
+
+    template <typename Problem>
+    using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;
+};
+
+using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
+
+auto create_args(int argc, char* argv[])
+{
+    ck_tile::ArgParser arg_parser;
+    arg_parser.insert("Ms", "", "M dimensions - empty by default.")
+        .insert("Ns", "", "N dimensions - empty by default.")
+        .insert("Ks", "", "K dimensions - empty by default.")
+        .insert(
+            "stride_As",
+            "",
+            "Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs
+                                                          // can be set to zero if
+                                                          // Ms/Ns/Ks is not empty
+        .insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
+        .insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
+        .insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.")
+        .insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.")
+        .insert("a_layout", "R", "A tensor data layout - Row by default.")
+        .insert("b_layout", "C", "B tensor data layout - Column by default.")
+        .insert("c_layout", "R", "C tensor data layout - Row by default.")
+        .insert("validate", "1", "0. No validation, 1. Validation on CPU.")
+        .insert("prec", "fp8", "data type. fp8/bf8")
fp16/bf16/fp8/bf8") + .insert("warmup", "10", "number of iterations before benchmark the kernel.") + .insert("repeat", "100", "number of iterations to benchmark the kernel.") + .insert("group_count", "8", "group count.") + .insert("kbatch", "1", "kbatch for SplitK") + .insert("init", "0", "0. Random, 2. One(s) (Constant)") + .insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.") + .insert("bquant_group_size", "1x1x128", "BQuant group size. 1x1x128 (default) or 1x128x128") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "abquant_grouped_gemm.json", "json file name to dump results"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +inline std::size_t get_workspace_size(const std::vector& gemm_descs) +{ + return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg); +} + +// Forward declaration of the non-persistent version +template +float grouped_gemm_abquant(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr); + +// Forward declaration of the tileloop version for persistent kernels +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc new file mode 100644 index 0000000000..bc5167439d --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc @@ -0,0 +1,604 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_abquant_gemm(int n_warmup, + int n_repeat, + int group_count, + const std::vector& args) +{ + // Workspace memory allocated to hold the gemm descriptions. + ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(args)); + + float ave_time = 0; + + if constexpr(!GemmConfig::Persistent) + { + ave_time = grouped_gemm_abquant( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have + // the gemm problems known on the host. Instead, we can just pass the pointer + // to the kernel and let the workgroups figure out which tiles to work on. + // This is useful when the gemm problems are generated dynamically. 
+        // In this example, however, we generate the `kargs` using the known gemm_descs,
+        // and copy the gemm descriptions to the device memory.
+        // The contents of the memory pointed to by `kargs_ptr` could be
+        // written by, e.g., another kernel from an earlier stage.
+        std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
+        void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
+        if(args[0].k_batch != 1)
+        {
+            throw std::runtime_error("Split-K not supported yet for persistent kernel");
+        }
+
+        for(const auto& arg : args)
+        {
+            kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
+                                                                   arg.b_ptr,
+                                                                   arg.aq_ptr,
+                                                                   arg.bq_ptr,
+                                                                   arg.e_ptr,
+                                                                   arg.M,
+                                                                   arg.N,
+                                                                   arg.K,
+                                                                   arg.QK_A,
+                                                                   arg.QK_B,
+                                                                   arg.stride_A,
+                                                                   arg.stride_B,
+                                                                   arg.stride_E,
+                                                                   arg.stride_AQ,
+                                                                   arg.stride_BQ,
+                                                                   arg.k_batch});
+        }
+        const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
+        HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
+                                            kargs.data(),
+                                            kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
+                                            hipMemcpyHostToDevice,
+                                            stream.stream_id_));
+        ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr);
+    }
+
+    return ave_time;
+}
+
+template
+int run_abquant_grouped_gemm_example_with_layouts(
+    int argc,
+    char* argv[],
+    const ALayout a_layout = ALayout{},
+    const AQLayout aq_layout = AQLayout{},
+    const BLayout b_layout = BLayout{},
+    const BQLayout bq_layout = BQLayout{},
+    [[maybe_unused]] const CLayout c_layout = CLayout{})
+{
+
+    auto [result, arg_parser] = create_args(argc, argv);
+
+    auto valid_input_data = [&](int group_count, const auto&... args) {
+        return group_count != 0 && ((args.size() == static_cast<std::size_t>(group_count)) && ...);
+    };
+
+    const int group_count = arg_parser.get_int("group_count");
+    const int repeat = arg_parser.get_int("repeat");
+    const int warmup = arg_parser.get_int("warmup");
+    const int kbatch = arg_parser.get_int("kbatch");
+    const int init_method = arg_parser.get_int("init");
+    bool validate = arg_parser.get_bool("validate");
+
+    if(kbatch > 1 && validate && warmup + repeat > 1)
+    {
+        std::cout << "WARNING: Data validation enabled with SplitK and more than "
+                  << "1 warmup/repeat. Disabling validation." << std::endl;
+        validate = false;
+    }
+
+    std::vector<int> Ms = arg_parser.get_int_vec("Ms");
+    std::vector<int> Ns = arg_parser.get_int_vec("Ns");
+    std::vector<int> Ks = arg_parser.get_int_vec("Ks");
+    std::vector<int> AQs; // dimension of AQ tensor is calculated from A tensor
+    std::vector<int> BQs; // dimension of BQ tensor is calculated from B tensor
+    std::vector<int> stride_As = arg_parser.get_int_vec("stride_As");
+    std::vector<int> stride_Bs = arg_parser.get_int_vec("stride_Bs");
+    std::vector<int> stride_Cs = arg_parser.get_int_vec("stride_Cs");
+    std::vector<int> stride_AQs = arg_parser.get_int_vec("stride_AQs");
+    std::vector<int> stride_BQs = arg_parser.get_int_vec("stride_BQs");
+
+    ck_tile::index_t AQK, BQK;
+
+    if(!valid_input_data(
+           group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs))
+    {
+        std::cout << "Please check the input data. Default values will be used."
<< std::endl; + + // Clear existing (invalid) data before adding defaults + Ms.clear(); + Ns.clear(); + Ks.clear(); + stride_As.clear(); + stride_Bs.clear(); + stride_Cs.clear(); + stride_AQs.clear(); + stride_BQs.clear(); + + for(int i = 0; i < group_count; i++) + { + + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 128 * i); + + // Let get_default_stride calculate based on layout + stride_As.push_back(0); + stride_Bs.push_back(0); + stride_Cs.push_back(0); + stride_AQs.push_back(0); + stride_BQs.push_back(0); + } + } + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> c_m_n_tensors; + std::vector> aq_tensors; + std::vector> bq_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + c_m_n_tensors.reserve(group_count); + aq_tensors.reserve(group_count); + bq_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> c_m_n_dev_buf; + std::vector> aq_dev_buf; + std::vector> bq_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + c_m_n_dev_buf.reserve(group_count); + aq_dev_buf.reserve(group_count); + bq_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + // For ABQuantGrouped, both A and B need quantization + static_assert(QuantMode == ck_tile::QuantType::ABQuantGrouped, + "This file only supports ABQuantGrouped mode"); + + AQK = K / AQuantGroupSize::kK; // Group quantization: AQK = K / AQuantGroupSize + BQK = K / BQuantGroupSize::kK; // Group quantization: BQK = K / BQuantGroupSize + if(K % AQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by AQuantGroupSize::kK for ABQuantGrouped mode"); + } + if(K % BQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by BQuantGroupSize::kK for ABQuantGrouped mode"); + } + + stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout)); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); + stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{})); + stride_AQs[i] = ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout)); + stride_BQs[i] = ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout)); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout)))); + c_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout)))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout)))); + + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc + << " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl; + + if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 
1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(bq_tensors[i]); + } + else + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(bq_tensors[i]); + } + + a_m_k_dev_buf.push_back(std::make_unique( + a_m_k_tensors[i].get_element_space_size_in_bytes())); + b_k_n_dev_buf.push_back(std::make_unique( + b_k_n_tensors[i].get_element_space_size_in_bytes())); + c_m_n_dev_buf.push_back(std::make_unique( + c_m_n_tensors[i].get_element_space_size_in_bytes())); + aq_dev_buf.push_back( + std::make_unique(aq_tensors[i].get_element_space_size_in_bytes())); + bq_dev_buf.push_back( + std::make_unique(bq_tensors[i].get_element_space_size_in_bytes())); + + a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + aq_dev_buf[i]->ToDevice(aq_tensors[i].data()); + bq_dev_buf[i]->ToDevice(bq_tensors[i].data()); + c_m_n_dev_buf[i]->SetZero(); + c_m_n_tensors[i].SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); + const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer(); + const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer(); + + gemm_descs.push_back({p_a, + p_b, + p_c, + p_aq, + p_bq, + kbatch, + M, + N, + K, + AQK, + BQK, + stride_As[i], + stride_Bs[i], + stride_Cs[i], + stride_AQs[i], + stride_BQs[i]}); + } + + float ave_time = invoke_abquant_gemm(warmup, repeat, group_count, gemm_descs); + + std::string op_name = "ABQuant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")"; + + std::size_t flop = 0, num_btype = 0; + for(int j = 0; j < group_count; ++j) + { + flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K; + + num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K + + sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N + + sizeof(CDataType) * gemm_descs[j].M * gemm_descs[j].N; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + for(int i = 0; i < group_count; i++) + { + c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data()); + } + + bool pass{true}; + if(validate) + { + for(int i = 0; i < group_count; ++i) + { + ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( + Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + // Reference implementation for ABQuantGrouped + ck_tile::reference_gemm_abquant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + Ks[i], kbatch, max_accumulated_value); + pass &= + ck_tile::check_err(c_m_n_tensors[i], + c_m_n_host_ref, + "Error: Incorrect results! 
in group [" + std::to_string(i) + "]", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "gemm[" << i + << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + } + + if(arg_parser.get_int("json") == 1) + { + dump_grouped_gemm_json_results(arg_parser.get_str("jsonfile"), + op_name, + group_count, + pass, + ave_time, + tflops, + gb_per_sec); + } + + return pass; +} + +template +int run_abquant_grouped_gemm_example_prec_type_with_bquant( + std::string a_layout, std::string b_layout, std::string c_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using Types = GemmTypeConfig; + // Specific type aliases for easy access + using ADataType = typename Types::ADataType; + using BDataType = typename Types::BDataType; + using AccDataType = typename Types::AccDataType; + using CDataType = typename Types::CDataType; + using AQDataType = typename Types::AccDataType; + using BQDataType = typename Types::AccDataType; + using AQuantGroupSize = ck_tile::QuantGroupShape>; + + constexpr auto QuantMode = ck_tile::QuantType::ABQuantGrouped; + + if(a_layout == "R" && b_layout == "C" && c_layout == "R") + { + return run_abquant_grouped_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); + } + else if(a_layout == "R" && b_layout == "R" && c_layout == "R") + { + return run_abquant_grouped_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R" && c_layout == "R") + { + return run_abquant_grouped_gemm_example_with_layouts( + argc, argv, Col{}, Row{}, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} + +template +int run_abquant_grouped_gemm_example_prec_type(std::string a_layout, + std::string b_layout, + std::string c_layout, + std::string bquant_group_size, + int argc, + char* argv[]) +{ + if(bquant_group_size == "1x1x128") + { + using BQuantGroupSize = ck_tile::QuantGroupShape>; + return run_abquant_grouped_gemm_example_prec_type_with_bquant( + a_layout, b_layout, c_layout, argc, argv); + } + else if(bquant_group_size == "1x128x128") + { + using BQuantGroupSize = ck_tile::QuantGroupShape>; + return run_abquant_grouped_gemm_example_prec_type_with_bquant( + a_layout, b_layout, c_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported BQuantGroupSize! 
Use 1x1x128 or 1x128x128."); + } +} + +template +int run_abquant_gemm_example_persistency(std::string a_layout, + std::string b_layout, + std::string c_layout, + bool persistent, + std::string bquant_group_size, + int argc, + char* argv[]) +{ + if(persistent) + { + using GemmConfig = typename GemmQuantConfig< + ck_tile::QuantType::ABQuantGrouped>::template GemmConfig; + return run_abquant_grouped_gemm_example_prec_type( + a_layout, b_layout, c_layout, bquant_group_size, argc, argv); + } + else + { + using GemmConfig = typename GemmQuantConfig< + ck_tile::QuantType::ABQuantGrouped>::template GemmConfig; + return run_abquant_grouped_gemm_example_prec_type( + a_layout, b_layout, c_layout, bquant_group_size, argc, argv); + } +} + +int run_abquant_grouped_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string c_layout = arg_parser.get_str("c_layout"); + const std::string data_type = arg_parser.get_str("prec"); + bool persistent = arg_parser.get_bool("persistent"); + const std::string bquant_group_size = arg_parser.get_str("bquant_group_size"); + + if(data_type == "fp8") + { + return run_abquant_gemm_example_persistency( + a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv); + } + else if(data_type == "bf8") + { + return run_abquant_gemm_example_persistency( + a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type configuration."); + } +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index 4a90c07e05..155f19881e 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -69,4 +69,64 @@ void abquant_quantgrouped_instance_factory( BQuantGroupSize, ck_tile::QuantType::ABQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", 
+ "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp index e0e0a64416..62ca34b057 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp @@ -9,36 +9,194 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill; void bquant_quantgrouped_preshufflequant_instance_factory( std::unordered_map>& lut) { - using QuantGroupSize = ck_tile::QuantGroupShape>; lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; + + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = 
decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; lut[hash_multiple_strings( {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, @@ -47,10 +205,63 @@ void bquant_quantgrouped_preshufflequant_instance_factory( lut[hash_multiple_strings( {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = 
decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index d8988be7b0..607c53d9af 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -74,9 +74,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true, ck_tile::BaseGemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BaseGemmPipelineAgBgCrCompV3>>>; + std::conditional_t< + QuantMode == ck_tile::QuantType::AQuantGrouped, + ck_tile::BaseGemmPipelineAgBgCrMem, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -145,26 +146,33 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str GemmConfig::Scheduler, has_hot_loop_v, tail_number_v>>>>; + using AQuantPipeline = + std::conditional_t, + ck_tile::AQuantGemmPipelineAgBgCrMem>; + + using BQuantPipeline = std::conditional_t< + GemmConfig::PreshuffleB, + ck_tile::WPQuantBPipelineAgBgCrV2, + std::conditional_t< + std::is_same_v, + ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + + using ABQuantPipeline = + std::conditional_t, + ck_tile::ABQuantGemmPipelineAgBgCrCompV3>; using GemmPipeline = std::conditional_t< QuantMode == ck_tile::QuantType::RowColQuant || QuantMode == ck_tile::QuantType::TensorQuant, ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t< - QuantMode == ck_tile::QuantType::AQuantGrouped, - std::conditional_t, - ck_tile::AQuantGemmPipelineAgBgCrMem>, - std::conditional_t< - QuantMode == ck_tile::QuantType::ABQuantGrouped, - 
ck_tile::ABQuantGemmPipelineAgBgCrCompV3, - std::conditional_t< - GemmConfig::PreshuffleB == true, - ck_tile::WPQuantBPipelineAgBgCrV2, - std::conditional_t< - std::is_same_v, - ck_tile::MxFp4GemmPipelineAgBgCrCompV3, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>>>; + std::conditional_t>>; constexpr bool TiledPermuteN = (BQuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN; @@ -532,7 +540,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, QuantMode == ck_tile::QuantType::RowColQuant) { bq_tensor_ptr = std::make_unique>( - ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout))); + ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout))); } else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) { @@ -908,8 +916,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - if((QuantMode == ck_tile::QuantType::ABQuantGrouped || - QuantMode == ck_tile::QuantType::AQuantGrouped || + if((QuantMode == ck_tile::QuantType::AQuantGrouped || QuantMode == ck_tile::QuantType::RowColQuant || std::is_same_v) && GemmConfig::PreshuffleB) @@ -938,7 +945,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped || QuantMode == ck_tile::QuantType::ABQuantGrouped) && - !GemmConfig::PreshuffleQuant) + !GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB) { if(a_layout == "R" && b_layout == "R") { diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 1e4aece0d7..696de378aa 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -3,6 +3,7 @@ #pragma once #include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" +#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" @@ -24,6 +25,8 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp new file mode 100644 index 0000000000..63a5151108 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -0,0 +1,282 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// BQ (scale tensor) is block distributed tensor. +// Consecutive QuantGroupSize elements of B are quantized with a separate scale. +// B is block window on block distributed tensor. +// C is block distributed tensor +template +struct BlockGemmWeightPreshuffleABQuantARegBRegCReg +{ + private: + template + struct GemmTraits_ + { + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using BQLayout = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using AQuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr auto Scheduler = Problem::Scheduler; + + // Threadblock GEMM tile size + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t NQPerBlock = NPerBlock / BQuantGroupSize::kN; + static constexpr index_t KQPerBlock = KPerBlock / BQuantGroupSize::kK; + static constexpr index_t AQPerBlock = KPerBlock / AQuantGroupSize::kK; + + static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + // number of warps along M and N for threadblock's GEMM problem size + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + using I0 = number<0>; + using I1 = number<1>; + + static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}), + "Error! WarpGemm's MWarp is not consistent with BlockGemmShape!"); + static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}), + "Error! WarpGemm's NWarp is not consistent with BlockGemmShape!"); + static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}), + "Error! WarpGemm's M is not consistent with BlockGemmShape!"); + static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}), + "Error! WarpGemm's N is not consistent with BlockGemmShape!"); + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + + static constexpr index_t QScalesPerBlockRow = + integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); + static constexpr index_t QScalesPerWarpGemmRow = + integer_divide_ceil(WarpGemm::kK, BQuantGroupSize::kK); + + static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; + + static_assert(BQuantGroupSize::kK % WarpGemm::kK == 0, + "Error! WarpGemm::kK should be a multiple of QuantGroupSize"); + static_assert(QScalesPerWarpGemmRow == 1, + "Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK"); + static_assert(KIterPerWarp % QScalesPerBlockRow == 0, + "Error! KItersPerWarp should be a multiple of QscalesPerBlockRow"); + + static_assert(KPerBlock / BQuantGroupSize::kK > 0, + "Error! 
Each row of blockgemm should have a separate scale"); + + static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock, + "Error! Warps should cover all Block tile!"); + static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock, + "Error! Warps should cover all Block tile!"); + + // Currently tested combinations (A, B, BQ) + // 1. fp8, fp8, fp32 -> f32 + // 2. bf8, bf8, fp32 -> f32 + // 3. i4, fp8, (fp8/fp32) -> f32 + // 4. i4, bf8, (fp8/fp32) -> f32 + static_assert( + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v) && + std::is_same_v); + + static constexpr index_t InterWaveSchedulingMacClusters = 1; + + static constexpr index_t KPack = WarpGemm::kKPerThread; + static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; + static constexpr bool TransposeC = Problem::TransposeC; + }; + + public: + using Traits = GemmTraits_; + using Problem = remove_cvref_t; + using BlockPolicy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; // TileFlatmmShape + using QuantGroupSize = remove_cvref_t; + + static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + + static constexpr auto warp_size = get_warp_size(); + + using WG = remove_cvref_t())>; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); // 128 / (1 * 16) = 8 + static constexpr index_t NIterPerWarp = + BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN)); // 128 / (4 * 16) = 2 + static constexpr index_t KIterPerWarp = KPerBlock / WG::kK; // 128 / 16 = 8 + static constexpr auto MIter_2nd_last = + (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; + + static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK; + + static constexpr index_t QScalesPerBlockRow = + integer_divide_ceil(KPerBlock, QuantGroupSize::kK); // 128 / 128 = 1 + static constexpr index_t QScalesPerWarpGemmRow = + integer_divide_ceil(WG::kK, QuantGroupSize::kK); + + static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; // 8 / 1 = 8 + static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read + + static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) + ? 
DsReadPreload + : MIterPerWarp * KIterPerWarp; + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + return BlockGemmQuantCommon:: + MakeCBlockTile(); + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + ABlockTensor& a_warp_tensor, + BFlatBlockTensor& b_warp_tensor, + AQBlockTensor& aq_block_tensor, + BQBlockTensor& bq_block_tensor, + ABlockWindow& a_warp_windows) const + { + using CWarpDstr = typename WG::CWarpDstr; + using AccTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + statically_indexed_array, MIterPerWarp> + c_acc; + + auto zero_accumulators = [&] { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) { + c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f; + }); // make sure WG::CWarpTensor exposes a clear/zero + }); + }); + }; + static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) { + zero_accumulators(); + static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // warp GEMM + WG{}(c_acc(mIter)(nIter), + a_warp_tensor(number{}), + b_warp_tensor(nIter)(number{})); + }); + __builtin_amdgcn_sched_barrier(0x7F6); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows(number{})(number{})); + } + // barrier + // Could be deleted + if constexpr((mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + AQPickerCommon aq_picker(aq_block_tensor); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + index_t reg_offset = [&]() { + if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN)) + { + return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ + + kQScale; + } + else + { + return nIter * KPerBlockBQ + kQScale; + } + }(); + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float b_scale_reg_f = + aq_picker.template cvt_scale_to_fp32(scale_reg); + + static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { + float a_scale_reg_f = aq_picker.template pick(); + auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; + const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; + c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f; + }); + }); + }); + }); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 16a0835b1d..313e449c7b 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -322,6 +322,7 @@ struct BQuantBlockUniversalGemmAsBsCr constexpr index_t reg_offset = nIter; auto pull_from_lane = (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + 
kQScale; + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; // cross lane ops uint32_t scale_reg_dword; diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 4f79361037..004fb18e0b 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -280,12 +280,13 @@ struct QuantGemmKernel // Helper: Create Pre-shuffled Quantization Tensor Descriptor // =================================================================== template CK_TILE_DEVICE static auto - MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QK_B) + MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QN_B, index_t QK_B) { // Step 1: Calculate base BQ tensor dimensions // ---------------------------------------------------------- @@ -304,8 +305,9 @@ struct QuantGemmKernel // ---------------------------------------------------------- // Pad the X dimension to be a multiple of block_tile_size to ensure // each thread block can process complete tiles without edge cases - const auto block_tile_size = NPerBlock * KPerBlockBQ; - const auto bq_pad0_desc = transform_tensor_descriptor( + const auto block_tile_size = NPerBlockBQ * KPerBlockBQ; + + const auto bq_pad0_desc = transform_tensor_descriptor( bq_desc, make_tuple(make_pass_through_transform(bq_y), make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))), @@ -318,7 +320,7 @@ struct QuantGemmKernel // This separates the work into tiles that can be processed by // individual warps/waves const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; - const auto wave_tile_size = WarpTileN * KPerBlockBQ; + const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ; const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); const auto bq_unmerge_pad0_desc = transform_tensor_descriptor( @@ -813,12 +815,18 @@ struct QuantGemmKernel static_assert(std::is_same_v, "PreshuffleQuant with BQuantGrouped currently only supports " "ColumnMajor BQ layout"); + using QuantGroupSize = remove_cvref_t; return MakePreshuffledQuantTensorView< GemmPipeline::KPerBlockBQ, + GemmPipeline::NPerBlockBQ, GemmPipeline::NPerBlock, TilePartitioner::BlockGemmShape::WarpTile::at(I1), - GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B); + GemmPipeline::GetVectorSizeBQ()>( + bq_ptr, + ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN), + QuantGroupSize::kN, + kargs.QK_B); } else { @@ -879,13 +887,38 @@ struct QuantGemmKernel if constexpr(PreshuffleQuant) { static_assert(std::is_same_v); - constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; - constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); - constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - constexpr auto tile_window_width = + constexpr auto block_n = + TilePartitioner::NPerBlock / + QuantGroupSize::kN; // Number of N-dimension quantization groups per block + constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at( + I1); // Number of N-dimension elements per warp + constexpr auto warp_per_group = + (QuantGroupSize::kN < + warp_n) // Determine how many warps share the same scale in N-dimension + ? 
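+                    // How to read this ternary (illustrative numbers): for
+                    // QuantGroupSize::kN = 8 and warp_n = 16 it yields 16 / 8 = 2
+                    // scale groups inside one warp; for QuantGroupSize::kN = 32
+                    // and warp_n = 16 it yields 32 / 16 = 2 warps sharing one
+                    // scale. The single name covers both granularities.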
(warp_n / QuantGroupSize::kN) + : (QuantGroupSize::kN / warp_n); + constexpr auto bqk_per_block = + TilePartitioner::KPerBlock / + QuantGroupSize::kK; // Number of K-dimension quantization groups per block + constexpr auto + tile_window_width = // The pre-shuffled layout flattens warp_n × + // bqk_per_block scales per row, Padded up to warp_size + // to ensure coalesced memory access. ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size()); - constexpr auto tile_window_height = block_n / warp_n; - auto block_n_idx = i_n / block_n; + + // Adapts based on fine vs coarse quantization granularity: + // - Fine-grained (QuantGroupSize::kN < warp_n): + // Multiple quant groups per warp → fewer rows needed per block. + // height = block_n / warp_per_group + // + // - Coarse-grained (QuantGroupSize::kN >= warp_n): + // Each row represents one quant group. + // height = block_n + constexpr auto tile_window_height = + (QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n; + auto block_n_idx = + i_n / TilePartitioner::NPerBlock; // Converts the global N-index (i_n) to a + // block index. return make_tile_window( bq_tensor_view, @@ -1125,596 +1158,6 @@ struct QuantGemmKernel return true; } - template - CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_ptr, - const AQDataType* aq_ptr, - const BQDataType* bq_ptr, - CDataType* c_ptr, - const QuantGemmKernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) - { - - static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); - const auto& a_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - }(); - - const auto& aq_tensor_view = [&]() { - if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) - { - static_assert(std::is_same_v); - const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ; - const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ; - const auto aq_desc = - make_naive_tensor_descriptor(make_tuple(aq_y, aq_x), - make_tuple(aq_x, 1), - number{}, - number<1>{}); - - const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ; - const auto aq_pad0_desc = transform_tensor_descriptor( - aq_desc, - make_tuple( - make_pass_through_transform(aq_y), - make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1]; - const auto wave_tile_size = - GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ; - const auto wave_tile_count_x = - ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size); - - const auto aq_unmerge_pad0_desc = transform_tensor_descriptor( - aq_pad0_desc, - make_tuple( - make_pass_through_transform(aq_y), - make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto aq_pad1_desc = transform_tensor_descriptor( - aq_unmerge_pad0_desc, - make_tuple( - make_pass_through_transform(aq_y), - make_pass_through_transform(wave_tile_count_x), - make_right_pad_transform( - wave_tile_size, 
get_padding_size(wave_tile_size, get_warp_size()))), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); - - const auto pad_wave_size = - ck_tile::integer_least_multiple(wave_tile_size, get_warp_size()); - const auto aq_merge_pad1_desc = transform_tensor_descriptor( - aq_pad1_desc, - make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)), - make_pass_through_transform(pad_wave_size)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tensor_view(aq_ptr, aq_merge_pad1_desc); - } - else if constexpr((kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::ABQuantGrouped) && - !PreshuffleQuant) - { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.M, kargs.QK_A), - make_tuple(kargs.stride_AQ, 1), - number{}, - number<1>{}); - } - else // Column major AQ - { - return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.QK_A, kargs.M), // Swapped dimensions - make_tuple(kargs.stride_AQ, 1), // Same stride pattern - number{}, - number<1>{}); - } - } - else if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, 0), // broadcasting over n - number<1>{}, - number<1>{}); - } - else - { - return nullptr; // TODO: use some other "empty" type for this - } - }(); - - const auto& b_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - if constexpr(GemmPipeline::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; - constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return make_tensor_view(b_ptr, b_n_k_desc); - } - else - { - return make_naive_tensor_view( - b_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.N), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - } - } - else - { - if constexpr(GemmPipeline::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; - constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - return make_tensor_view(b_ptr, b_n_k_desc); - } - else - { - if constexpr(PreshuffleB) - { - index_t kFlatK = GemmPipeline::flatKPerWarp * - (splitk_batch_offset.splitted_k / - GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); - index_t kFlatN = kargs.N * kargs.K / kFlatK; - return make_naive_tensor_view( - b_ptr, - make_tuple(kFlatN, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); - } - else - { - if 
constexpr(std::is_same_v) - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k / 2), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - else - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - } - } - } - }(); - - const auto& bq_tensor_view = [&]() { - if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_naive_tensor_view( - bq_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(0, 1), // broadcasting over m - number<1>{}, - number<1>{}); - } - else if constexpr(kQuantType == QuantType::BQuantGrouped) - { - if constexpr(PreshuffleQuant) - { - static_assert(std::is_same_v, - "PreshuffleQuant with BQuantGrouped currently only supports " - "ColumnMajor BQ layout"); - - return MakePreshuffledQuantTensorView< - GemmPipeline::KPerBlockBQ, - GemmPipeline::NPerBlock, - TilePartitioner::BlockGemmShape::WarpTile::at(I1), - GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B); - } - else - { - using QuantGroupSize = remove_cvref_t; - - if constexpr(std::is_same_v) - { - // For RowMajor BQ: memory layout is [K/QuantGroupK][N/QuantGroupN] - // Dimensions: [K/QuantGroupK, N/QuantGroupN] - // Strides: [N/QuantGroupN, 1] - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), - integer_divide_ceil(kargs.N, QuantGroupSize::kN)), - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1), - number{}, - number<1>{}); - } - else - { - static_assert(std::is_same_v); - // For ColumnMajor BQ: memory layout is [N/QuantGroupN][K/QuantGroupK] - // Dimensions: [N/QuantGroupN, K/QuantGroupK] - // Strides: [K/QuantGroupK, 1] - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), - integer_divide_ceil(kargs.K, QuantGroupSize::kK)), - make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1), - number{}, - number<1>{}); - } - } - } - else if constexpr(kQuantType == QuantType::ABQuantGrouped) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), - make_tuple(kargs.stride_BQ, 1), - number{}, - number<1>{}); - } - else - { - return nullptr; // TODO: use some other "empty" type for this - } - }(); - - // TODO: enable vector write for C in ColMajor - const auto& c_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - c_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_C, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - c_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_C), - number<1>{}, - number<1>{}); - } - }(); - - return make_tuple( - a_tensor_view, aq_tensor_view, b_tensor_view, bq_tensor_view, c_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - // no padding - const auto& aq_pad_view = [&]() { return views.at(I1); }(); - - const auto& b_flat_view = views.at(I2); // not applying any padding to flat B view - - 
const auto& b_pad_view = [&]() { - const auto& b_tensor_view = views.at(I2); - if constexpr(std::is_same_v) - { - if constexpr(std::is_same_v) - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - else - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - // no padding - const auto& bq_pad_view = [&]() { return views.at(I3); }(); - - // TODO vector write in for C in ColMajor - const auto& c_pad_view = [&]() { - const auto& c_tensor_view = views.at(I4); - if constexpr(std::is_same_v) - { - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - if constexpr(PreshuffleB) - { - - return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view); - } - else - { - return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view); - } - } - - template - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - - const auto& a_pad_view = views.at(I0); - const auto& aq_pad_view = views.at(I1); - const auto& b_pad_view = views.at(I2); - const auto& bq_pad_view = views.at(I3); - const auto& c_pad_view = views.at(I4); - const auto& a_block_window = [&]() { - if constexpr(std::is_same_v) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } - }(); - - const auto& aq_block_window = [&]() { - if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - constexpr auto block_m = TilePartitioner::MPerBlock; - constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0); - constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - constexpr auto tile_window_width = - ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size()); - constexpr auto tile_window_height = block_m / warp_m; - auto block_m_idx = i_m / block_m; - return make_tile_window( - aq_pad_view, - make_tuple(number{}, number{}), - {block_m_idx * tile_window_height, 0}); - } - else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant) - { - using QuantGroupSize = remove_cvref_t; - constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - constexpr auto block_m = TilePartitioner::MPerBlock; - if constexpr(std::is_same_v) - { - return make_tile_window(aq_pad_view, - make_tuple(number{}, number{}), - {i_m, 0}); - } - else // Column major AQ - { - return make_tile_window(aq_pad_view, - make_tuple(number{}, number{}), - {0, i_m}); - } - } - else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - constexpr auto block_m = TilePartitioner::MPerBlock; - constexpr auto block_k = TilePartitioner::KPerBlock; - return make_tile_window( - aq_pad_view, - make_tuple(number{}, number{}), - {i_m, 0}); - } - else if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_tile_window(aq_pad_view, - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return nullptr; // TODO: use some other "empty" 
type? - } - }(); - - const auto& b_block_window = [&]() { - if constexpr(PreshuffleB) - { - - return make_tile_window( - b_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0}); - } - else - { - if constexpr(std::is_same_v) - { - if constexpr(std::is_same_v) - return make_tile_window( - b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - else - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - } - else - { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {0, i_n}); - } - } - }(); - - const auto& bq_block_window = [&]() { - if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_tile_window(bq_pad_view, - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else if constexpr(kQuantType == QuantType::BQuantGrouped) - { - using QuantGroupSize = remove_cvref_t; - if constexpr(PreshuffleQuant) - { - static_assert(std::is_same_v); - constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; - constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); - constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - constexpr auto tile_window_width = - ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size()); - constexpr auto tile_window_height = block_n / warp_n; - auto block_n_idx = i_n / block_n; - - return make_tile_window( - bq_pad_view, - make_tuple(number{}, number{}), - {block_n_idx * tile_window_height, 0}); - } - else - { - if constexpr(std::is_same_v) - { - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {0, i_n / QuantGroupSize::kN}); - } - else - { - static_assert(std::is_same_v); - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {i_n / QuantGroupSize::kN, 0}); - } - } - } - else if constexpr(kQuantType == QuantType::ABQuantGrouped) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {i_n / QuantGroupSize::kN, 0}); - } - else - { - return nullptr; // TODO: use some other "empty" type here - } - }(); - - auto c_block_window = make_tile_window( - c_pad_view, - make_tuple(number{}, number{}), - {i_m, i_n}); - - return make_tuple( - a_block_window, aq_block_window, b_block_window, bq_block_window, c_block_window); - } - /** * @brief Runs single GEMM problem cooperatively by whole workgroup. 
* diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 06a80c8b55..c9e725f5fd 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -484,6 +484,17 @@ struct QuantGroupedGemmKernel tail_num, smem_ptr); } + else if constexpr(kQuantType == QuantType::ABQuantGrouped) + { + return GemmPipeline{}.template operator()(a_block_window, + b_block_window, + aq_block_window, + bq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr); + } else if constexpr(kQuantType == QuantType::RowColQuant || kQuantType == QuantType::TensorQuant) { @@ -499,7 +510,8 @@ struct QuantGroupedGemmKernel c_ptr, kargs, block_idx_m, block_idx_n); if constexpr(kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::BQuantGrouped) + kQuantType == QuantType::BQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) { EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } @@ -527,7 +539,8 @@ struct QuantGroupedGemmKernel c_ptr, kargs, block_idx_m, block_idx_n); if constexpr(kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::BQuantGrouped) + kQuantType == QuantType::BQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) { EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index 39f0cbdbd3..a4bba6cf76 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -48,7 +48,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; - constexpr index_t VecLoadSize = GetVectorSizeBQ(); constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; using WarpTile = typename Problem::BlockGemmShape::WarpTile; @@ -68,7 +67,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC BlockSize, NPerBlock / WarpGemm::kN, ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()), - VecLoadSize, + Problem::BQuantGroupSize::kN, + Problem::BQuantGroupSize::kK, BQLayout, PreshuffleQuant>; return TileEncodingPattern::make_2d_static_tile_distribution(); @@ -83,6 +83,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC KPerBlockBQ, // Logical K dimension NPerBlockBQ, // Logical N dimension Problem::BQuantGroupSize::kN, + Problem::BQuantGroupSize::kK, BQLayout>; return TileEncodingPattern::make_2d_static_tile_distribution(); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index b43066cdc5..13d400d5fc 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -65,8 +65,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(); } static constexpr index_t GetVectorSizeB() { return 
Policy::template GetVectorSizeB(); } @@ -300,9 +302,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}), - 0) + (PreshuffleQuant) + ? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, NPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), + 0) : is_bq_row_major ? make_array(KPerBlockBQ, 0) : make_array(0, KPerBlockBQ); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index 0ec8942426..34f815ed27 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -192,6 +192,7 @@ template struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern @@ -208,31 +209,6 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding static_assert(num_warps == MWarps * NWarps * KWarps); static_assert(KWarps == 1); - /// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales) - /// - /// This function determines the optimal thread distribution pattern for loading and applying - /// quantization scales to the B matrix based on the quantization group size (NPerQ) relative - /// to warp dimensions. - /// - /// Three distinct distribution patterns are handled: - /// - /// 1. Fine-grained quantization (NPerQ < WarpGemm::kN): - /// - Multiple quantization groups exist within a single warp's N-dimension - /// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp) - /// - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast - /// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp - /// - /// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps): - /// - Each warp handles exactly one quantization scale - /// - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN - /// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 - /// - /// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps): - /// - Quantization group spans multiple warps - /// - All warps share the same scale value - /// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale - /// - /// @return A static tile distribution encoding for the BQ scale tensor CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { // Preshuffle only supported for ColumnMajor currently @@ -241,22 +217,136 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding if constexpr(PreshuffleQuant) { - // ColumnMajor only for preshuffle - constexpr index_t X1 = warp_size; - constexpr index_t X0 = NPerTile / warp_size; - constexpr index_t Y1 = NWarps; - constexpr index_t Y0 = KPerTile / Y1; + // ============================================================================= + // PRE-SHUFFLED BQ SCALE TILE DISTRIBUTION + // ============================================================================= + // For pre-shuffled quantization, the BQ scale tensor has been reorganized + // (pre-shuffled) to optimize memory access patterns during dequantization. 
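+            // Worked example (assumed shapes, not derived from this Problem):
+            // with warp_n = 16 N-scales per warp row and bqk_per_block = 2
+            // K-groups per block, a pre-shuffled row carries 16 * 2 = 32 scales,
+            // padded up to the 64-lane warp width so each lane issues one
+            // aligned load.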
+ // + // Tile Dimensions: + // - K-axis (Y in encoding): Corresponds to the K-dimension iteration + // - N-axis (X in encoding): Flattened scale index combining N and K groups + // + // The encoding distributes work across threads such that each thread loads + // the correct pre-shuffled scale for its corresponding B-matrix elements. + // ============================================================================= + if constexpr(NPerQ <= WarpGemm::kN) + { + // ========================================================================= + // CASE 1: Fine-grained Quantization (NPerQ <= WarpGemm::kN) + // ========================================================================= + // Multiple quantization scales exist within a single warp's N-dimension. + // Each warp processes multiple scales: WarpGemm::kN / NPerQ scales per warp. + // + // Example: NPerQ=8, WarpGemm::kN=16, KPerQ=128, BlockGemmShape::kK=256 + // → 2 scales per warp in N, 2 K-groups per block + constexpr auto N1 = BlockGemmShape::kK / + KPerQ; // Number of K-dimension quantization groups per block, + // Each K-group of KPerQ elements shares the same scale. + constexpr auto N0 = + WarpGemm::kN / NPerQ; // Number of scales per warp in N-dimension, Since NPerQ + // <= WarpGemm::kN, each warp handles multiple scales. + constexpr auto N2 = 1; // Elements per thread + constexpr auto NR1 = NPerQ; // Elements sharing the same scale in N-dimension + constexpr auto NR0 = + warp_size / + (N0 * N1 * N2 * NR1); // Interleave factor to ensure full warp utilization + constexpr auto K1 = NWarps; // Number of warps distributed along this dimension + constexpr auto K0 = KPerTile / K1; // Iterations per warp to cover the K-tile + constexpr auto KR = 1; // No replication in K-dimension - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2>>, - tuple, sequence<1>>, - sequence<1, 2>, - sequence<0, 0>>{}); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0, 2, 0>>, + tuple, sequence<1, 0, 2, 1, 3>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } + else if constexpr(NPerQ < WarpGemm::kN * NWarps) + { + // ========================================================================= + // CASE 2: Medium-grained Quantization (WarpGemm::kN < NPerQ < WarpGemm::kN * + // NWarps) + // ========================================================================= + // Each warp handles exactly one quantization scale in N-dimension. + // Some warps share the same scale (KR > 1 creates warp grouping). 
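+                // Note (implied by the integer divisions below): this branch
+                // assumes NWarps is a multiple of KR = NPerQ / WarpGemm::kN,
+                // otherwise K1 = NWarps / KR would truncate.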
+ // + // Example: NPerQ=32, WarpGemm::kN=16, NWarps=4 + // → KR=2 (2 warps share same scale), K1=2 (2 unique scale groups) + + constexpr auto KR = NPerQ / WarpGemm::kN; // Number of warps sharing the same scale + constexpr auto K1 = NWarps / KR; // Number of distinct warp groups (unique scales) + constexpr auto K0 = KPerTile / K1; // Iterations to cover K-tile per warp group + constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups + constexpr auto N0 = 1; // Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN) + constexpr auto N2 = 1; // Elements per thread + constexpr auto NR1 = NPerQ; // Scale broadcast factor (full NPerQ) + constexpr auto NR0 = + warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0, 2>>, + tuple, sequence<1, 0, 2, 1>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } + else + { + // ========================================================================= + // CASE 3: Coarse-grained Quantization (NPerQ >= WarpGemm::kN * NWarps) + // ========================================================================= + // The quantization group spans ALL warps in N-dimension. + // All warps share the same scale value for their N-tiles. + // + // Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 + // → 128 >= 16*4=64, so all 4 warps use the same scale + constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups + constexpr auto N0 = 1; // Minimal (1) since scale is shared across N + constexpr auto N2 = 1; // Elements per thread + constexpr auto NR1 = 32; // Fixed broadcast size + constexpr auto NR0 = + warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0, 2>>, + tuple, sequence<2, 0, 3, 1>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } } else { + /// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales) + /// + /// This function determines the optimal thread distribution pattern for loading and + /// applying quantization scales to the B matrix based on the quantization group size + /// (NPerQ) relative to warp dimensions. + /// + /// Three distinct distribution patterns are handled: + /// + /// 1. Fine-grained quantization (NPerQ < WarpGemm::kN): + /// - Multiple quantization groups exist within a single warp's N-dimension + /// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp) + /// - Distribution includes explicit replication factor (XR = NPerQ) for scale + /// broadcast + /// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp + /// + /// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps): + /// - Each warp handles exactly one quantization scale + /// - Scales are distributed across warps with replication factor XR = NPerQ / + /// WarpGemm::kN + /// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 + /// + /// 3. 
Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps): + /// - Quantization group spans multiple warps + /// - All warps share the same scale value + /// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale + /// + /// @return A static tile distribution encoding for the BQ scale tensor if constexpr(NPerQ < WarpGemm::kN) { // Case 1: Fine-grained - multiple quantization scales within a single warp diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp new file mode 100755 index 0000000000..80e41cad45 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp @@ -0,0 +1,120 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" + +namespace ck_tile { + +struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelineAgBgCrPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ() + { + using AQDataType = remove_cvref_t; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK; + + return GetABQGlobalVectorLoadSize(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeAQDramTileDistribution() + { + return GemmAQuantPipelineAgBgCrDefaultPolicy::MakeAQDramTileDistribution(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ() + { + using BQDataType = remove_cvref_t; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; + + return GetABQGlobalVectorLoadSize(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution() + { + return GemmBQuantPipelineAgBgCrDefaultPolicy::MakeBQDramTileDistribution(); + } + + // as UniversalWeightPreshufflePipelineAgBgCrPolicy's MakeBFlatDramTileDistribution is changed; + // move original UniversalWeightPreshufflePipelineAgBgCrPolicy's implementation to here + // temporarily + template + CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + constexpr index_t KBPerLoad = GetKBPerLoad(); +#if defined(__gfx11__) + constexpr index_t KRepeatInWave = 2; +#else + constexpr index_t KRepeatInWave = 1; +#endif + constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim + constexpr index_t KWavePerBlk = 1; + constexpr index_t KRepeat = 1; + static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); + + constexpr index_t NBPerLoad = 1; + constexpr index_t NThdPerWave = 1; + constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp 
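+        // Worked sizing example (assuming the 64-lane wave of the non-gfx11
+        // path, KRepeatInWave = 1): KThdPerWave = 64, so the static_assert
+        // above pins flatKPerWarp to 64 * KBPerLoad, one flat-K load per lane.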
+ constexpr index_t NRepeat = 1; + + constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // ? + tuple, // second direction + sequence>, // first direction + // wave in blk, // thd in wave + // // + tuple, sequence<0, 1, 2>>, // which direction + tuple, sequence<1, 2, 2>>, // which index + // + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffleBQuant() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + using BTypeToUse = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; + + using WarpGemm = WarpGemmDispatcher; + + // TODO : Use a custom block policy for AsBrCr + using BlockGemmPolicy = + BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy; + return BlockGemmWeightPreshuffleABQuantARegBRegCReg{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp new file mode 100644 index 0000000000..0f3951ffcc --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -0,0 +1,611 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +template +struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV2 +{ + using Base = WeightPreshufflePipelineAGmemBGmemCRegV2; + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using AQuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using BQLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockWeightPreshuffle = remove_cvref_t< + decltype(PipelinePolicy::template GetBlockWeightPreshuffleBQuant())>; + + static constexpr auto config = + BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + using Base::kKPerBlock; + using Base::kMPerBlock; + using Base::kNPerBlock; + + using Base::KIterPerWarp; + using Base::MIterPerWarp; + using Base::NIterPerWarp; + + using Base::BlockSize; + + using Base::kPadK; + using Base::kPadM; + using Base::kPadN; + + using Base::I0; + using Base::I1; + using Base::I2; + + using Base::MWarp; + using Base::NWarp; + + using Base::KPerBlockPerIter; + using Base::MPerBlockPerIter; + + using Base::flatKPerWarp; + using Base::flatNPerWarp; + + using Base::m_preload; + + static constexpr index_t VectorLoadSize = Problem::VectorLoadSize; + static constexpr index_t KPerBlockAQ = + integer_divide_ceil(BlockGemmShape::kK, AQuantGroupSize::kK); + static constexpr index_t 
KPerBlockBQ =
+        integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK);
+    static constexpr index_t QScalesPerBlockRow =
+        integer_divide_ceil(kKPerBlock, BQuantGroupSize::kK);
+    static constexpr index_t GetVectorSizeAQ()
+    {
+        return PipelinePolicy::template GetVectorSizeAQ();
+    }
+    static constexpr index_t GetVectorSizeBQ()
+    {
+        return PipelinePolicy::template GetVectorSizeBQ();
+    }
+    static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
+
+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
+    {
+        // clang-format off
+        constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0);
+        constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1);
+        return concat('_', "abquant_pipeline_AgBgCrV2_preshuffleB",
+                      concat('x', kMPerBlock, kNPerBlock, kKPerBlock),
+                      BlockSize,
+                      concat('x', WaveNumM, WaveNumN),
+                      concat('x', Base::GetVectorSizeA(), Base::GetVectorSizeB(), GetVectorSizeAQ(), GetVectorSizeBQ()),
+                      concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName());
+        // clang-format on
+    }
+
+    template
+    CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
+    {
+        // Estimated number of VMEM vector loads for A per block:
+        // total A bytes / (threads per block * vector width)
+        constexpr index_t Aload_inst =
+            (kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize;
+        // Estimated number of VMEM vector loads for B per block:
+        // total B bytes / (threads per block * vector width)
+        constexpr index_t Bload_inst =
+            (kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize;
+
+        // Estimated number of VMEM loads for B's quant data (e.g. scales / zp).
+        // First ceil-divide by the quant group element count (how many elements
+        // share one scale), then by vector width to approximate the vector loads.
+        constexpr index_t BQload_inst = ck_tile::integer_divide_ceil(
+            ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType),
+                                         BQuantGroupSize::kN * BQuantGroupSize::kK),
+            VectorLoadSize);
+
+        // TODO: hardcoded, needs tuning later. Number of instructions emitted per iteration.
+        constexpr index_t kLdsInstCycle = 8;
+        // Total VMEM load instructions (A + B + quant data)
+        constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst;
+        // Approximate number of LDS reads per block
+        constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle;
+        // Approximate number of LDS writes per block
+        // (e.g., writing A from VMEM into LDS once per A load)
+        constexpr index_t ds_write_inst = Aload_inst;
+        // Number of MFMA instructions per wave for one block tile:
+        constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN);
+        // How often (in MFMA units) we should insert DS (LDS) operations.
+        constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst);
+        // How often (in MFMA units) we should insert VMEM buffer loads.
+        // buffer_load_rep ≈ "MFMA per VMEM_READ", clamped so that one buffer_load
+        // is assumed to cover at most 4 MFMA instructions.
+        constexpr index_t buffer_load_rep =
+            min(mfma_inst / buffer_load_inst, 4); // one buffer_load covers at most 4 MFMAs
+
+        static_for<0, nloop, 1>{}([&](auto) {
+            static_for<0, mfma_inst, 1>{}([&](auto i_inst) {
+                __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA
+
+                // Insert LDS read/write groups periodically based on ds_rep.
+                // The % pattern staggers READ and WRITE so they don't collapse
+                // into the same cycle in the model.
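+                // Worked example (assumed 128x128x128 fp8 tile, BlockSize = 256,
+                // VectorLoadSize = 16, 16x16 WarpGemm): Aload_inst = Bload_inst = 4,
+                // mfma_inst = 8 * 8 = 64, ds_read_inst = 128 / 8 = 16, giving
+                // ds_rep = 64 / (16 + 4) = 3 and buffer_load_rep clamped to 4.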
+ if constexpr(ds_rep > 0 && i_inst % ds_rep == 0) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::DS_READ, 1, 0); // DS read + } + if constexpr(ds_rep > 0 && i_inst % ds_rep == 1) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write + } + + if constexpr(buffer_load_rep > 0 && i_inst % buffer_load_rep == 0) + { + if constexpr(ds_write_inst > 0) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read + } + } + // Always mark some VALU work in the loop to reflect auxiliary scalar + // or vector ALU instructions that coexist with MFMA (Blockscale calculation). + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + + static constexpr bool PreshuffleB = Problem::PreshuffleB; + static constexpr auto TailNum = Problem::TailNum; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t m, + index_t n, + index_t num_loop, + void* p_smem) const + { + (void)m; + (void)n; + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "A/B/BQ Dram block window should have the same data type as appropriate " + "([A|B|BQ]DataType) defined in Problem definition!"); + + constexpr bool is_a_col_major = std::is_same_v; + static_assert(!is_a_col_major, "A must be row major (col major not supported yet)"); + + constexpr bool is_bq_col_major = std::is_same_v; + static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); + + constexpr bool is_b_row_major = std::is_same_v; + static_assert(!is_b_row_major, "B must be col major (row major not supported yet)"); + + const index_t iMWarp = get_warp_id() / NWarp; + // Double-Buffering (loop_count=2) for full load/compute overlap. 
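+        // Sketch of the resulting schedule: even iterations compute from the
+        // "ping" LDS/register set while "pong" is refilled, odd iterations swap
+        // roles, so each trip of the main loop below consumes two K-blocks and
+        // iCounter = (num_loop - 1) / 2 trips are executed.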
+ const index_t loop_count = 2; + + __builtin_amdgcn_sched_barrier(0); + + // A tile in LDS + constexpr index_t smem_size = PipelinePolicy::template GetSmemSize(); + ADataType* p_a_lds_ping = static_cast(p_smem); + ADataType* p_a_lds_pong = + reinterpret_cast(static_cast(p_smem) + smem_size); + + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); + + auto a_lds_block_ping = + make_tensor_view(p_a_lds_ping, a_lds_block_desc); + auto a_lds_block_pong = + make_tensor_view(p_a_lds_pong, a_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeADramTileDistribution()); + + auto a_copy_lds_window_ping = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + + auto a_copy_lds_window_pong = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + + // ping-pong window for A LDS + auto a_warp_window_ping_tmp = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + auto a_warp_window_pong_tmp = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows_ping; + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows_pong; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; + + move_tile_window(a_warp_windows_ping(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + + move_tile_window(a_warp_windows_pong(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Block GEMM + auto block_weight_preshuffle = BlockWeightPreshuffle(); + // Acc register tile + auto c_block_tile = block_weight_preshuffle.MakeCBlockTile(); + + // B flat DRAM window for load + auto b_flat_distribution = + PipelinePolicy::template MakeBFlatDramTileDistribution(); + auto b_flat_dram_window = // tile_window_with_static_distribution + make_tile_window( + b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); + + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; + using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); + + // pingpong buffer for B + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_flat_dram_windows; + + statically_indexed_array, NIterPerWarp> + b_warp_tensor_ping; + + statically_indexed_array, NIterPerWarp> + b_warp_tensor_pong; + + auto aq_copy_dram_window = + make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(), + aq_dram_block_window_tmp.get_window_lengths(), + aq_dram_block_window_tmp.get_window_origin(), + 
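+                             // (AQ scales stay in registers: unlike A they are
+                             // never staged through LDS, so a plain DRAM window
+                             // with the policy's AQ distribution suffices.)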
PipelinePolicy::template MakeAQDramTileDistribution()); + // BQ DRAM window for load + auto bq_copy_dram_window = + make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), + bq_dram_block_window_tmp.get_window_lengths(), + bq_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeBQDramTileDistribution()); + + // Prefetch A0 + auto a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // prefetch B + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + + load_int4_tile( + b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + }); + }); + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + // Strictly not needed given type deduction, but helps with readability + using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); + using AQBlockTile = + decltype(make_static_distributed_tensor(AQBlockTileDistr{})); + using BQBlockTileDistr = decltype(bq_copy_dram_window.get_tile_distribution()); + using BQBlockTile = + decltype(make_static_distributed_tensor(BQBlockTileDistr{})); + + // Load tile 0 for BQ data directly into registers for block tile + AQBlockTile aq_block_tile, aq_block_tile_2; + BQBlockTile bq_block_tile, bq_block_tile_2; + aq_block_tile = load_tile(aq_copy_dram_window); + bq_block_tile = load_tile(bq_copy_dram_window); + // move BQ to tile 1 + move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + // Prefill A0 + auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + + __builtin_amdgcn_sched_barrier(0); + + // Prefetch A1 + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + block_sync_lds(); + + // preload A00,A10 from lds + statically_indexed_array{})(number<0>{}))), + m_preload> + a_warp_tensor; + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor(loadIter) = + load_tile(a_warp_windows_ping(number{})(number{})); + }); + __builtin_amdgcn_sched_barrier(0); + + // MAIN LOOP + index_t iCounter = (num_loop - 1) / loop_count; + + while(iCounter > 0) + { + __builtin_amdgcn_sched_barrier(0); + // Prefill A(2i+1) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // Prefetch A(2i+2) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_ping, + aq_block_tile, + bq_block_tile, + a_warp_windows_ping); + // prefetch B(2i+1) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + load_int4_tile( + 
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + }); + }); + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + aq_block_tile_2 = load_tile(aq_copy_dram_window); + move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); + bq_block_tile_2 = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor(loadIter) = + load_tile(a_warp_windows_pong(number{})(number{})); + }); + + // Next K + + // prefetch B(2i+2) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + load_int4_tile( + b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + }); + }); + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + aq_block_tile = load_tile(aq_copy_dram_window); + move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + + // Prefill A(2i+2) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + + // Prefetch A(2i+3) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i+1 + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_pong, + aq_block_tile_2, + bq_block_tile_2, + a_warp_windows_pong); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor(loadIter) = + load_tile(a_warp_windows_ping(number{})(number{})); + }); + iCounter--; + HotLoopScheduler(); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // prefetch B(loopK) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + + load_int4_tile( + b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + }); + }); + aq_block_tile_2 = load_tile(aq_copy_dram_window); + bq_block_tile_2 = load_tile(bq_copy_dram_window); + + // Prefill A(loopK) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // GEMM loopK-1 + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_ping, + aq_block_tile, + bq_block_tile, + a_warp_windows_ping); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor(loadIter) = + load_tile(a_warp_windows_pong(number{})(number{})); + }); + + // GEMM loopK + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_pong, + aq_block_tile_2, + bq_block_tile_2, + a_warp_windows_pong); + HotLoopScheduler(); + } + else if constexpr(TailNum == TailNumber::Odd) + { + // GEMM loopK + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_ping, + aq_block_tile, + bq_block_tile, + 
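+                                    // (Odd tail: the last K-block is computed
+                                    // from the ping buffers; the Even path above
+                                    // runs one extra pong-side GEMM first.)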
a_warp_windows_ping);
+            Base::LastHotLoopScheduler();
+        }
+
+        return c_block_tile;
+    }
+
+    template
+    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
+                                   const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
+                                   const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
+                                   const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
+                                   index_t num_loop,
+                                   void* p_smem,
+                                   index_t m = 0,
+                                   index_t n = 0) const // Default values for the non-preshuffle case
+    {
+        return operator()(
+            a_dram_block_window_tmp,
+            [](const ADataType& a) { return a; },
+            b_flat_dram_block_window_tmp,
+            aq_dram_block_window_tmp,
+            bq_dram_block_window_tmp,
+            m,
+            n,
+            num_loop,
+            p_smem);
+    }
+
+    template
+    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
+                                   const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
+                                   const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
+                                   const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
+                                   index_t num_loop,
+                                   TailNumber tail_number,
+                                   void* p_smem,
+                                   index_t n = 0) const
+    {
+        const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
+            (void)bool_val; // Suppress unused parameter warning
+            constexpr auto tail_num = tail_num_.value;
+            (void)tail_num; // Tail selection comes from Problem::TailNum in this pipeline
+            return operator()(
+                a_dram_block_window_tmp,
+                [](const ADataType& a) { return a; },
+                b_flat_dram_block_window_tmp,
+                aq_dram_block_window_tmp,
+                bq_dram_block_window_tmp,
+                0, // m - unused by this pipeline
+                n, // n - unused by this pipeline
+                num_loop,
+                p_smem);
+        };
+        return Base::TailHandler(RunPipeline, true, tail_number);
+    }
+};
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp
index 43f37ec4d8..e4de7e4211 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp
@@ -71,6 +71,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
     static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
 
     static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
+    static constexpr index_t NPerBlockBQ =
+        integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
     static constexpr index_t KPerBlockBQ =
         integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
     static constexpr index_t QScalesPerBlockRow =
@@ -352,8 +354,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
         if constexpr(PreshuffleQuant)
         {
             move_tile_window(bq_copy_dram_window,
-                             {ck_tile::integer_least_multiple(n, kNPerBlock) /
-                                  BlockGemmShape::WarpTile::at(number<1>{}),
+                             {((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
+                                   ?
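+                                 // Illustration (assuming 4 waves in N): with
+                                 // QuantGroupSize::kN = 128 and kNPerBlock = 128,
+                                 // NPerBlockBQ = 1 < 4, so the window advances by
+                                 // whole quant groups rather than warp-tile strides.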
ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, kNPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), 0}); } else @@ -462,8 +468,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV if constexpr(PreshuffleQuant) { move_tile_window(bq_copy_dram_window, - {ck_tile::integer_least_multiple(n, kNPerBlock) / - BlockGemmShape::WarpTile::at(number<1>{}), + {((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, kNPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), 0}); } else diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 197c9d6e1d..93cd7fa063 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -9,6 +9,7 @@ add_subdirectory(grouped_gemm) add_subdirectory(grouped_gemm_preshuffle) add_subdirectory(grouped_gemm_multi_d) add_subdirectory(grouped_gemm_quant) +add_subdirectory(grouped_gemm_abquant) add_subdirectory(gemm_multi_d) add_subdirectory(gemm_multi_abd) add_subdirectory(gemm_streamk) diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index f89aea1c17..2dad8be205 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -39,6 +39,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_padding PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_abquant_preshuffle + test_gemm_quant_abquant_preshuffle_2d.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + # AQuant tests add_gtest_executable(test_tile_gemm_quant_aquant_prefill test_gemm_quant_aquant_prefill.cpp ) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp new file mode 100644 index 0000000000..793c9bd1df --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+
+#include "ck_tile/host.hpp"
+#include "ck_tile/ops/gemm.hpp"
+
+#include <gtest/gtest.h>
+#include <tuple>
+
+#include "test_gemm_quant_fixtures.hpp"
+
+// Type aliases for readability
+using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
+using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
+using FP8 = ck_tile::fp8_t;
+using BF8 = ck_tile::bf8_t;
+using Half = ck_tile::half_t;
+using PkInt4 = ck_tile::pk_int4_t;
+using ABQuantGrouped =
+    std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
+using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
+
+// 2D block size for BQuant
+using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
+
+// Type combinations for ABQuant tests (see the TestCkTileGemmABQuant fixture for
+// the tuple layout)
+// clang-format off
+using ABQuantPreshuffleBTypes = ::testing::Types<
+    // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
+    std::tuple,
+    std::tuple
+>;
+// clang-format on
+
+// Test suite for ABQuant
+TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleBTypes);
+
+// ABQuant tests
+TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
+{
+    this->run_test_with_validation(1024, 1024, 1024);
+}
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
index 6fb1b77fa8..3798cc4443 100644
--- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
@@ -894,10 +894,10 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase;
-        using BaseGemmPipeline =
-            std::conditional_t,
-                               ck_tile::BaseGemmPipelineAgBgCrCompV3>;
+        using BaseGemmPipeline = std::conditional_t<
+            PreshuffleB == true,
+            ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2,
+            ck_tile::BaseGemmPipelineAgBgCrCompV3>;
 
         const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
         const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
@@ -926,8 +926,8 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase;
 
             using GemmPipeline =
-                std::conditional_t,
+                std::conditional_t,
                                    ck_tile::ABQuantGemmPipelineAgBgCrCompV3>;
 
             using GemmEpilogue = ck_tile::CShuffleEpilogue<
diff --git a/test/ck_tile/grouped_gemm_abquant/CMakeLists.txt b/test/ck_tile/grouped_gemm_abquant/CMakeLists.txt
new file mode 100644
index 0000000000..e735aa8e9a
--- /dev/null
+++ b/test/ck_tile/grouped_gemm_abquant/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+# SPDX-License-Identifier: MIT
+
+set(EXAMPLE_GEMM_COMPILE_OPTIONS)
+if(CK_USE_OCP_FP8)
+    list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
+endif()
+
+if(GPU_TARGETS MATCHES "gfx94|gfx95")
+    add_gtest_executable(test_ck_tile_grouped_gemm_abquant_1x1x128 test_grouped_gemm_abquant_1x1x128.cpp)
+    target_compile_options(test_ck_tile_grouped_gemm_abquant_1x1x128 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
+
+    add_gtest_executable(test_ck_tile_grouped_gemm_abquant_1x128x128 test_grouped_gemm_abquant_1x128x128.cpp)
+    target_compile_options(test_ck_tile_grouped_gemm_abquant_1x128x128 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
+endif()
+
diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x128x128.cpp b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x128x128.cpp
new file mode 100644
index 0000000000..06b0068cb7
--- /dev/null
+++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x128x128.cpp
@@ -0,0 +1,47 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include <tuple>
+
+#include "gtest/gtest.h"
+
+#include "ck_tile/host.hpp"
+#include "test_grouped_gemm_abquant_util.hpp"
+
+using F16 = ck_tile::half_t;
+using F32 = float;
+using FP8 = ck_tile::fp8_t;
+using BF8 = ck_tile::bf8_t;
+using Row = ck_tile::tensor_layout::gemm::RowMajor;
+using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
+using True = ck_tile::bool_constant<true>;
+using False = ck_tile::bool_constant<false>;
+
+// AQuant group size is fixed at 1x1x128
+using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
+// BQuant group size: 1x128x128
+using BQuantGroupSize_1x128x128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
+
+// clang-format off
+using KernelTypes_ABQuant_1x128x128 = ::testing::Types<
+    // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, AQuantGroupSize, BQuantGroupSize, Persistent
+
+    // FP8 variants
+    std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>,
+    std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>,
+    std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>,
+    std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>,
+    std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>,
+    std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>,
+
+    // BF8 variants
+    std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>,
+    std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>
+    >;
+// clang-format on
+
+TYPED_TEST_SUITE(TestCkTileGroupedGemmABQuant_1x128x128, KernelTypes_ABQuant_1x128x128);
+
+#define TEST_CLASS_NAME TestCkTileGroupedGemmABQuant_1x128x128
+#include "test_grouped_gemm_abquant_ut_cases.inc"
+#undef TEST_CLASS_NAME
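+
+// Reading the first FP8 row above: A is row-major fp8 with fp32 scales grouped
+// 1x1x128 along K (per-token quant), B is column-major fp8 with fp32 scales in
+// 1x128x128 tiles (2D block quant), accumulation in fp32, fp16 output, and the
+// non-persistent grouped-GEMM dispatch path.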
diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x1x128.cpp b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x1x128.cpp
new file mode 100644
index 0000000000..7704eda169
--- /dev/null
+++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x1x128.cpp
@@ -0,0 +1,47 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include <tuple>
+
+#include "gtest/gtest.h"
+
+#include "ck_tile/host.hpp"
+#include "test_grouped_gemm_abquant_util.hpp"
+
+using F16 = ck_tile::half_t;
+using F32 = float;
+using FP8 = ck_tile::fp8_t;
+using BF8 = ck_tile::bf8_t;
+using Row = ck_tile::tensor_layout::gemm::RowMajor;
+using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
+using True = ck_tile::bool_constant<true>;
+using False = ck_tile::bool_constant<false>;
+
+// AQuant group size is fixed at 1x1x128
+using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
+// BQuant group size: 1x1x128
+using BQuantGroupSize_1x1x128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
+
+// clang-format off
+using KernelTypes_ABQuant_1x1x128 = ::testing::Types<
+    // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, AQuantGroupSize, BQuantGroupSize, Persistent
+
+    // FP8 variants
+    std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>,
+    std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>,
+    std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>,
+    std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>,
+    std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>,
+    std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>,
+
+    // BF8 variants
+    std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>,
+    std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>
+    >;
+// clang-format on
+
+TYPED_TEST_SUITE(TestCkTileGroupedGemmABQuant_1x1x128, KernelTypes_ABQuant_1x1x128);
+
+#define TEST_CLASS_NAME TestCkTileGroupedGemmABQuant_1x1x128
+#include "test_grouped_gemm_abquant_ut_cases.inc"
+#undef TEST_CLASS_NAME
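+
+// A new configuration only needs a group shape and a type row; the 1x64x128
+// BQuant shape below is hypothetical and shown purely for illustration:
+//   using BQuantGroupSize_1x64x128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
+//   std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x64x128, False>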
diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_ut_cases.inc b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_ut_cases.inc
new file mode 100644
index 0000000000..48574ab977
--- /dev/null
+++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_ut_cases.inc
@@ -0,0 +1,87 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+TYPED_TEST(TEST_CLASS_NAME, Basic)
+{
+    const int group_count = 6;
+    std::vector<int> Ms;
+    std::vector<int> Ns;
+    std::vector<int> Ks;
+    std::vector<int> stride_As;
+    std::vector<int> stride_Bs;
+    std::vector<int> stride_Cs;
+    std::vector<int> stride_AQs;
+    std::vector<int> stride_BQs;
+    for(int i = 0; i < group_count; i++)
+    {
+        Ms.push_back(256 + 256 * i);
+        Ns.push_back(256 + 512 * i);
+        Ks.push_back(512 + 128 * i);
+
+        stride_As.push_back(0);
+        stride_Bs.push_back(0);
+        stride_Cs.push_back(0);
+        stride_AQs.push_back(0);
+        stride_BQs.push_back(0);
+    }
+
+    this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
+}
+
+// No-hot-loop test case: verifies kernel correctness when the main loop is not taken.
+// Problem size 256x256x256 is a small multiple of the test kernel's tile size
+// (M_Tile=128, N_Tile=128, K_Tile=128).
+TYPED_TEST(TEST_CLASS_NAME, SmallUniform)
+{
+    const int group_count = 2;
+    std::vector<int> Ms;
+    std::vector<int> Ns;
+    std::vector<int> Ks;
+    std::vector<int> stride_As;
+    std::vector<int> stride_Bs;
+    std::vector<int> stride_Cs;
+    std::vector<int> stride_AQs;
+    std::vector<int> stride_BQs;
+    for(int i = 0; i < group_count; i++)
+    {
+        Ms.push_back(256);
+        Ns.push_back(256);
+        Ks.push_back(256);
+
+        stride_As.push_back(0);
+        stride_Bs.push_back(0);
+        stride_Cs.push_back(0);
+        stride_AQs.push_back(0);
+        stride_BQs.push_back(0);
+    }
+
+    this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
+}
+
+TYPED_TEST(TEST_CLASS_NAME, OddTail)
+{
+    const int group_count = 2;
+    std::vector<int> Ms;
+    std::vector<int> Ns;
+    std::vector<int> Ks;
+    std::vector<int> stride_As;
+    std::vector<int> stride_Bs;
+    std::vector<int> stride_Cs;
+    std::vector<int> stride_AQs;
+    std::vector<int> stride_BQs;
+    for(int i = 0; i < group_count; i++)
+    {
+        Ms.push_back(256);
+        Ns.push_back(256);
+        Ks.push_back(128);
+
+        stride_As.push_back(0);
+        stride_Bs.push_back(0);
+        stride_Cs.push_back(0);
+        stride_AQs.push_back(0);
+        stride_BQs.push_back(0);
+    }
+
+    this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
+}
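+
+// Note on coverage (assuming the fp8 configurations registered by the test
+// files, where the util header sets K_Tile = 128): Basic exercises varied
+// per-group sizes with a hot loop, SmallUniform (K = 256, two K iterations)
+// hits the no-hot-loop even-tail path, and OddTail (K = 128, one K iteration)
+// hits the odd-tail path. Strides of 0 request the packed default stride for
+// each layout (see the ck_tile::get_default_stride calls in the util header).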
diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp
new file mode 100644
index 0000000000..c7ed6f5472
--- /dev/null
+++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp
@@ -0,0 +1,530 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#pragma once
+#include <gtest/gtest.h>
+#include <tuple>
+#include <vector>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host.hpp"
+#include "ck_tile/host/kernel_launch.hpp"
+#include "ck_tile/ops/epilogue.hpp"
+#include "ck_tile/ops/gemm.hpp"
+#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
+#include "ck_tile/ops/gemm_quant.hpp"
+
+template <typename Tuple>
+class TestCkTileGroupedGemmABQuant : public ::testing::Test
+{
+    protected:
+    using ALayout = std::tuple_element_t<0, Tuple>;
+    using BLayout = std::tuple_element_t<1, Tuple>;
+    using CLayout = std::tuple_element_t<2, Tuple>;
+    using ADataType = std::tuple_element_t<3, Tuple>;
+    using AQDataType = std::tuple_element_t<4, Tuple>;
+    using BDataType = std::tuple_element_t<5, Tuple>;
+    using BQDataType = std::tuple_element_t<6, Tuple>;
+    using AccDataType = std::tuple_element_t<7, Tuple>;
+    using CDataType = std::tuple_element_t<8, Tuple>;
+    using AQuantGroupSize = std::tuple_element_t<9, Tuple>;
+    using BQuantGroupSize = std::tuple_element_t<10, Tuple>;
+    static constexpr bool Persistent = std::tuple_element_t<11, Tuple>::value;
+
+    using Row = ck_tile::tensor_layout::gemm::RowMajor;
+    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
+    using AQLayout = Row;
+    using BQLayout = Col;
+
+    static constexpr auto QuantMode = ck_tile::QuantType::ABQuantGrouped;
+
+    struct GemmConfig
+    {
+        static constexpr bool kPadM = false;
+        static constexpr bool kPadN = false;
+        static constexpr bool kPadK = false;
+
+        static constexpr int kBlockPerCu = 1;
+        static constexpr ck_tile::index_t M_Tile = 128;
+        static constexpr ck_tile::index_t N_Tile = 128;
+        static constexpr ck_tile::index_t K_Tile = 128 / sizeof(ADataType);
+
+        static constexpr ck_tile::index_t M_Warp = 1;
+        static constexpr ck_tile::index_t N_Warp = 4;
+        static constexpr ck_tile::index_t K_Warp = 1;
+
+        static constexpr ck_tile::index_t M_Warp_Tile = 16;
+        static constexpr ck_tile::index_t N_Warp_Tile = 16;
+        static constexpr ck_tile::index_t K_Warp_Tile =
+            ck_tile::get_k_warp_tile();
+
+        static constexpr bool PreshuffleB = false;
+        static constexpr bool TransposeC = false;
+        static constexpr bool DoubleSmemBuffer = false;
+        static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
+
+        static constexpr bool IsPersistent = Persistent;
+    };
+
+    using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
+
+    std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
+    {
+        return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
+    }
+
+    template <typename Layout>
+    static constexpr inline auto is_row_major(Layout layout_)
+    {
+        return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
+                                                     ck_tile::tensor_layout::gemm::RowMajor>>{};
+    }
+
+    auto calculate_rtol_atol(const ck_tile::index_t K,
+                             const ck_tile::index_t kbatch,
+                             const float max_accumulated_value)
+    {
+        using ComputeType =
+            std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
+        const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
+            ck_tile::integer_divide_ceil(K, kbatch));
+        const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
+            max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
+        const auto rtol_split_k =
+            ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
+        const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
+            max_accumulated_value, kbatch);
+        return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
+    }
+
+    template <typename Config = GemmConfig>
+    float invoke_grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
+                                      const ck_tile::stream_config& s,
+                                      void* kargs_ptr)
+    {
+        constexpr ck_tile::index_t TileParitionerGroupNum = 8;
+        constexpr ck_tile::index_t 
TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile:: + TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * Config::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * Config::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = Config::Scheduler; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Config::M_Warp, + Config::N_Warp, + Config::M_Warp_Tile, + Config::N_Warp_Tile, + Config::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } + + template + void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr) + { + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Config::M_Warp, + Config::N_Warp, + 
Config::M_Warp_Tile, + Config::N_Warp_Tile, + Config::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + } + + public: + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + std::vector& stride_As, + std::vector& stride_Bs, + std::vector& stride_Cs, + std::vector& stride_AQs, + std::vector& stride_BQs, + const int group_count = 8) + { + ck_tile::index_t AQK, BQK; + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> c_m_n_tensors; + std::vector> aq_tensors; + std::vector> bq_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + c_m_n_tensors.reserve(group_count); + aq_tensors.reserve(group_count); + bq_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> c_m_n_dev_buf; + std::vector> aq_dev_buf; + std::vector> bq_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + c_m_n_dev_buf.reserve(group_count); + aq_dev_buf.reserve(group_count); + bq_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + AQK = K / AQuantGroupSize::kK; + BQK = K / BQuantGroupSize::kK; + + if(K % AQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by AQuantGroupSize::kK for ABQuantGrouped mode"); + } + if(K % BQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by BQuantGroupSize::kK for ABQuantGrouped mode"); + } + + stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(ALayout{})); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(BLayout{})); + stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{})); + stride_AQs[i] = + ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout{})); + stride_BQs[i] = + ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(BQLayout{})); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(ALayout{})))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(BLayout{})))); + c_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(BQLayout{})))); + + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc + << " c_m_n: " << c_m_n_tensors[i].mDesc << " aq: " << aq_tensors[i].mDesc + << " bq: " << 
bq_tensors[i].mDesc << std::endl; + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(bq_tensors[i]); + + a_m_k_dev_buf.push_back(std::make_unique( + a_m_k_tensors[i].get_element_space_size_in_bytes())); + b_k_n_dev_buf.push_back(std::make_unique( + b_k_n_tensors[i].get_element_space_size_in_bytes())); + c_m_n_dev_buf.push_back(std::make_unique( + c_m_n_tensors[i].get_element_space_size_in_bytes())); + aq_dev_buf.push_back(std::make_unique( + aq_tensors[i].get_element_space_size_in_bytes())); + bq_dev_buf.push_back(std::make_unique( + bq_tensors[i].get_element_space_size_in_bytes())); + + a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + aq_dev_buf[i]->ToDevice(aq_tensors[i].data()); + bq_dev_buf[i]->ToDevice(bq_tensors[i].data()); + c_m_n_dev_buf[i]->SetZero(); + c_m_n_tensors[i].SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); + const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer(); + const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer(); + + gemm_descs.push_back({p_a, + p_b, + p_c, + p_aq, + p_bq, + 1, // k_batch + M, + N, + K, + AQK, + BQK, + stride_As[i], + stride_Bs[i], + stride_Cs[i], + stride_AQs[i], + stride_BQs[i]}); + } + + ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(gemm_descs)); + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + + if constexpr(Persistent) + { + std::vector kargs; + for(const auto& arg : gemm_descs) + { + kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.aq_ptr, + arg.bq_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.QK_A, + arg.QK_B, + arg.stride_A, + arg.stride_B, + arg.stride_E, + arg.stride_AQ, + arg.stride_BQ, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, false, 1}; + ck_tile::hip_check_error( + hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr); + } + else + { + const auto stream = ck_tile::stream_config{nullptr, false, 1}; + invoke_grouped_gemm_abquant(gemm_descs, stream, kargs_ptr); + } + + // Copy results back to host for validation + for(int i = 0; i < group_count; i++) + { + c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data()); + } + + bool pass{true}; + for(int i = 0; i < group_count; ++i) + { + ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( + Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_abquant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(Ks[i], 1, max_accumulated_value); + pass &= + ck_tile::check_err(c_m_n_tensors[i], + c_m_n_host_ref, + "Error: Incorrect results! 
in group [" + std::to_string(i) + "]", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "gemm[" << i + << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + + EXPECT_TRUE(pass); + } +}; + +// Aliases for split test files +template +using TestCkTileGroupedGemmABQuant_1x1x128 = TestCkTileGroupedGemmABQuant; + +template +using TestCkTileGroupedGemmABQuant_1x128x128 = TestCkTileGroupedGemmABQuant;