From 27de2f8fc82d7e1d5d7253f14244bc390a942e21 Mon Sep 17 00:00:00 2001 From: kyle-256 Date: Wed, 7 Jan 2026 04:36:04 +0800 Subject: [PATCH] [CKTILE] Support A/B Quantization in Blockscale Grouped Gemm (#3452) * update grouped_gemm blockwise kernel * update config * update kernel * update examples * remove test code for now * sync test files with origin/develop * update example * fix code lint * fix code-lint * update test code * run clang format * run pre-commit * update api [ROCm/composable_kernel commit: 76696ace4460a5bcf79d9a75a97b30c76507e284] --- .../ck_tile/17_grouped_gemm/CMakeLists.txt | 3 +- .../17_grouped_gemm/abquant_grouped_gemm.cpp | 278 ++++++++ .../17_grouped_gemm/abquant_grouped_gemm.hpp | 171 +++++ .../run_grouped_gemm_abquant_example.inc | 604 ++++++++++++++++++ .../kernel/grouped_gemm_quant_kernel.hpp | 17 +- test/ck_tile/CMakeLists.txt | 1 + .../grouped_gemm_abquant/CMakeLists.txt | 16 + .../test_grouped_gemm_abquant_1x128x128.cpp | 47 ++ .../test_grouped_gemm_abquant_1x1x128.cpp | 47 ++ .../test_grouped_gemm_abquant_ut_cases.inc | 87 +++ .../test_grouped_gemm_abquant_util.hpp | 530 +++++++++++++++ 11 files changed, 1798 insertions(+), 3 deletions(-) create mode 100644 example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp create mode 100644 example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp create mode 100644 example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc create mode 100644 test/ck_tile/grouped_gemm_abquant/CMakeLists.txt create mode 100644 test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x128x128.cpp create mode 100644 test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x1x128.cpp create mode 100644 test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_ut_cases.inc create mode 100644 test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index 9b51af22fe..0f0a0d8ba7 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -14,7 +14,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95") quant_grouped_gemm_bf8_rowcol.cpp quant_grouped_gemm_bf8_tensor.cpp ) - + add_executable(tile_example_abquant_grouped_gemm abquant_grouped_gemm.cpp) add_executable(tile_example_grouped_gemm_preshuffle grouped_gemm_preshuffle.cpp) add_executable(tile_example_grouped_gemm_multi_d grouped_gemm_multi_d.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) @@ -25,4 +25,5 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95") target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_abquant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp new file mode 100644 index 0000000000..84da1e26da --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp @@ -0,0 +1,278 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm_quant.hpp" +#include "ck_tile/host.hpp" +#include "abquant_grouped_gemm.hpp" + +// Non-persistent grouped gemm for ABQuant +template +float grouped_gemm_abquant(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + GemmQuantConfig::template BaseGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + +// Persistent grouped gemm tileloop for ABQuant +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + 
constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); +} + +#include "run_grouped_gemm_abquant_example.inc" + +int main(int argc, char* argv[]) +{ + int result1 = run_abquant_grouped_gemm_example(argc, argv); + return result1; +} diff --git a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp new file mode 100644 index 0000000000..da8bd5514c --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp @@ -0,0 +1,171 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <string>
+#include <tuple>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/kernel_launch.hpp"
+#include "ck_tile/ops/gemm.hpp"
+#include "ck_tile/utility/json_dump.hpp"
+
+template <typename PrecType>
+struct GemmTypeConfig;
+
+template <>
+struct GemmTypeConfig<ck_tile::fp8_t>
+{
+    using ADataType   = ck_tile::fp8_t;
+    using BDataType   = ck_tile::fp8_t;
+    using AccDataType = float;
+    using CDataType   = ck_tile::half_t;
+};
+template <>
+struct GemmTypeConfig<ck_tile::bf8_t>
+{
+    using ADataType   = ck_tile::bf8_t;
+    using BDataType   = ck_tile::bf8_t;
+    using AccDataType = float;
+    using CDataType   = ck_tile::half_t;
+};
+
+template <bool Persistent_>
+struct GemmConfigBase
+{
+    static constexpr bool kPadM = false;
+    static constexpr bool kPadN = false;
+    static constexpr bool kPadK = false;
+
+    static constexpr bool PermuteA = false;
+    static constexpr bool PermuteB = false;
+
+    static constexpr bool TransposeC            = false;
+    static constexpr bool UseStructuredSparsity = false;
+
+    static constexpr int kBlockPerCu                         = 1;
+    static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
+    static constexpr ck_tile::index_t TileParitionerM01      = 4;
+    static constexpr auto Scheduler                 = ck_tile::GemmPipelineScheduler::Intrawave;
+    static constexpr ck_tile::index_t NumWaveGroups = 1;
+    static constexpr bool DoubleSmemBuffer          = false;
+    static constexpr bool PreshuffleB               = false;
+    static constexpr bool Persistent                = Persistent_;
+};
+
+template <typename PrecType, bool Persistent_ = false>
+struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent_>
+{
+    static constexpr ck_tile::index_t M_Tile = 128;
+    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
+
+    static constexpr ck_tile::index_t M_Warp = 1;
+    static constexpr ck_tile::index_t N_Warp = 4;
+    static constexpr ck_tile::index_t K_Warp = 1;
+
+    static constexpr ck_tile::index_t M_Warp_Tile = 16;
+    static constexpr ck_tile::index_t N_Warp_Tile = 16;
+    static constexpr ck_tile::index_t K_Warp_Tile =
+        ck_tile::get_k_warp_tile();
+};
+
+template <ck_tile::QuantType QuantMode>
+struct GemmQuantConfig;
+
+// ABQuant specialization for GemmQuantConfig
+template <>
+struct GemmQuantConfig<ck_tile::QuantType::ABQuantGrouped>
+{
+    template <typename PrecType, bool Persistent_ = false>
+    using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent_>;
+
+    template <typename Problem>
+    using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3<Problem>;
+
+    template <typename Problem>
+    using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;
+};
+
+using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
+
+auto create_args(int argc, char* argv[])
+{
+    ck_tile::ArgParser arg_parser;
+    arg_parser.insert("Ms", "", "M dimensions - empty by default.")
+        .insert("Ns", "", "N dimensions - empty by default.")
+        .insert("Ks", "", "K dimensions - empty by default.")
+        .insert(
+            "stride_As",
+            "",
+            "Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs
+                                                          // can be set to zero if
+                                                          // Ms/Ns/Ks is not empty
+        .insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
+        .insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
+        .insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.")
+        .insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.")
+        .insert("a_layout", "R", "A tensor data layout - Row by default.")
+        .insert("b_layout", "C", "B tensor data layout - Column by default.")
+        .insert("c_layout", "R", "C tensor data layout - Row by default.")
+        .insert("validate", "1", "0. No validation, 1. Validation on CPU.")
+        .insert("prec", "fp8", "data type. fp8/bf8")
+        .insert("warmup", "10", "number of iterations before benchmarking the kernel.")
+        .insert("repeat", "100", "number of iterations to benchmark the kernel.")
+        .insert("group_count", "8", "group count.")
+        .insert("kbatch", "1", "kbatch for SplitK")
+        .insert("init", "0", "0. Random, 2. One(s) (Constant)")
+        .insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.")
+        .insert("bquant_group_size", "1x1x128", "BQuant group size. 1x1x128 (default) or 1x128x128")
+        .insert("json", "0", "0: No Json, 1: Dump Results in Json format")
+        .insert("jsonfile", "abquant_grouped_gemm.json", "json file name to dump results");
+
+    bool result = arg_parser.parse(argc, argv);
+    return std::make_tuple(result, arg_parser);
+}
+
+inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
+{
+    return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
+}
+
+// Forward declaration of the non-persistent version
+template 
+float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
+                           const ck_tile::stream_config& s,
+                           void* kargs_ptr);
+
+// Forward declaration of the tileloop version for persistent kernels
+template 
+float grouped_gemm_tileloop(const ck_tile::stream_config& s,
+                            const ck_tile::index_t num_groups,
+                            void* kargs_ptr);
diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc
new file mode 100644
index 0000000000..bc5167439d
--- /dev/null
+++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc
@@ -0,0 +1,604 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+template <typename Layout>
+static constexpr inline auto is_row_major(Layout layout_)
+{
+    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
+                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
+}
+
+template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
+auto calculate_rtol_atol(const ck_tile::index_t K,
+                         const ck_tile::index_t kbatch,
+                         const float max_accumulated_value)
+{
+    using ComputeType =
+        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
+    // Calculate thresholds
+    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
+        ck_tile::integer_divide_ceil(K, kbatch));
+    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
+        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
+    // Calculate error due to split_k accumulation
+    const auto rtol_split_k =
+        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
+    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
+        max_accumulated_value, kbatch);
+    // Use higher threshold
+    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
+}
+
+template 
+float invoke_abquant_gemm(int n_warmup,
+                          int n_repeat,
+                          int group_count,
+                          const std::vector<grouped_gemm_kargs>& args)
+{
+    // Workspace memory allocated to hold the gemm descriptions.
+    ck_tile::DeviceMem gemm_workspace;
+    gemm_workspace.Realloc(get_workspace_size(args));
+
+    float ave_time = 0;
+
+    if constexpr(!GemmConfig::Persistent)
+    {
+        ave_time = grouped_gemm_abquant(
+            args,
+            ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
+            gemm_workspace.GetDeviceBuffer());
+    }
+    else
+    {
+        // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
+        // the gemm problems known on the host. Instead, we can just pass the pointer
+        // to the kernel and let the workgroups figure out which tiles to work on.
+        // This is useful when the gemm problems are generated dynamically.
+        // In this example however, we generate the `kargs` using the known gemm_descs,
+        // and copy the gemm descriptions to the device memory.
+        // The contents of the memory pointed to by `kargs_ptr` could be
+        // written by e.g. another kernel from an earlier stage.
+        std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
+        void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
+        if(args[0].k_batch != 1)
+        {
+            throw std::runtime_error("Split-K not supported yet for persistent kernel");
+        }
+
+        for(const auto& arg : args)
+        {
+            kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
+                                                                   arg.b_ptr,
+                                                                   arg.aq_ptr,
+                                                                   arg.bq_ptr,
+                                                                   arg.e_ptr,
+                                                                   arg.M,
+                                                                   arg.N,
+                                                                   arg.K,
+                                                                   arg.QK_A,
+                                                                   arg.QK_B,
+                                                                   arg.stride_A,
+                                                                   arg.stride_B,
+                                                                   arg.stride_E,
+                                                                   arg.stride_AQ,
+                                                                   arg.stride_BQ,
+                                                                   arg.k_batch});
+        }
+        const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
+        HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
+                                            kargs.data(),
+                                            kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
+                                            hipMemcpyHostToDevice,
+                                            stream.stream_id_));
+        ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr);
+    }
+
+    return ave_time;
+}
+
+template 
+int run_abquant_grouped_gemm_example_with_layouts(
+    int argc,
+    char* argv[],
+    const ALayout a_layout                  = ALayout{},
+    const AQLayout aq_layout                = AQLayout{},
+    const BLayout b_layout                  = BLayout{},
+    const BQLayout bq_layout                = BQLayout{},
+    [[maybe_unused]] const CLayout c_layout = CLayout{})
+{
+
+    auto [result, arg_parser] = create_args(argc, argv);
+
+    auto valid_input_data = [&](int group_count, const auto&... args) {
+        return group_count != 0 &&
+               ((args.size() == static_cast<std::size_t>(group_count)) && ...);
+    };
+
+    const int group_count = arg_parser.get_int("group_count");
+    const int repeat      = arg_parser.get_int("repeat");
+    const int warmup      = arg_parser.get_int("warmup");
+    const int kbatch      = arg_parser.get_int("kbatch");
+    const int init_method = arg_parser.get_int("init");
+    bool validate         = arg_parser.get_bool("validate");
+
+    if(kbatch > 1 && validate && warmup + repeat > 1)
+    {
+        std::cout << "WARNING: Data validation enabled with SplitK and more than "
+                  << "1 warmup/repeat. Disabling validation." << std::endl;
+        validate = false;
+    }
+
+    std::vector<int> Ms = arg_parser.get_int_vec("Ms");
+    std::vector<int> Ns = arg_parser.get_int_vec("Ns");
+    std::vector<int> Ks = arg_parser.get_int_vec("Ks");
+    std::vector<int> AQs; // dimension of AQ tensor is calculated from A tensor
+    std::vector<int> BQs; // dimension of BQ tensor is calculated from B tensor
+    std::vector<int> stride_As  = arg_parser.get_int_vec("stride_As");
+    std::vector<int> stride_Bs  = arg_parser.get_int_vec("stride_Bs");
+    std::vector<int> stride_Cs  = arg_parser.get_int_vec("stride_Cs");
+    std::vector<int> stride_AQs = arg_parser.get_int_vec("stride_AQs");
+    std::vector<int> stride_BQs = arg_parser.get_int_vec("stride_BQs");
+
+    ck_tile::index_t AQK, BQK;
+
+    if(!valid_input_data(
+           group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs))
+    {
+        std::cout << "Please check the input data. Default values will be used."
<< std::endl; + + // Clear existing (invalid) data before adding defaults + Ms.clear(); + Ns.clear(); + Ks.clear(); + stride_As.clear(); + stride_Bs.clear(); + stride_Cs.clear(); + stride_AQs.clear(); + stride_BQs.clear(); + + for(int i = 0; i < group_count; i++) + { + + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 128 * i); + + // Let get_default_stride calculate based on layout + stride_As.push_back(0); + stride_Bs.push_back(0); + stride_Cs.push_back(0); + stride_AQs.push_back(0); + stride_BQs.push_back(0); + } + } + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> c_m_n_tensors; + std::vector> aq_tensors; + std::vector> bq_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + c_m_n_tensors.reserve(group_count); + aq_tensors.reserve(group_count); + bq_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> c_m_n_dev_buf; + std::vector> aq_dev_buf; + std::vector> bq_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + c_m_n_dev_buf.reserve(group_count); + aq_dev_buf.reserve(group_count); + bq_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + // For ABQuantGrouped, both A and B need quantization + static_assert(QuantMode == ck_tile::QuantType::ABQuantGrouped, + "This file only supports ABQuantGrouped mode"); + + AQK = K / AQuantGroupSize::kK; // Group quantization: AQK = K / AQuantGroupSize + BQK = K / BQuantGroupSize::kK; // Group quantization: BQK = K / BQuantGroupSize + if(K % AQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by AQuantGroupSize::kK for ABQuantGrouped mode"); + } + if(K % BQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by BQuantGroupSize::kK for ABQuantGrouped mode"); + } + + stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout)); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); + stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{})); + stride_AQs[i] = ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout)); + stride_BQs[i] = ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout)); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout)))); + c_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout)))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout)))); + + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc + << " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl; + + if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 
1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(bq_tensors[i]); + } + else + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(bq_tensors[i]); + } + + a_m_k_dev_buf.push_back(std::make_unique( + a_m_k_tensors[i].get_element_space_size_in_bytes())); + b_k_n_dev_buf.push_back(std::make_unique( + b_k_n_tensors[i].get_element_space_size_in_bytes())); + c_m_n_dev_buf.push_back(std::make_unique( + c_m_n_tensors[i].get_element_space_size_in_bytes())); + aq_dev_buf.push_back( + std::make_unique(aq_tensors[i].get_element_space_size_in_bytes())); + bq_dev_buf.push_back( + std::make_unique(bq_tensors[i].get_element_space_size_in_bytes())); + + a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + aq_dev_buf[i]->ToDevice(aq_tensors[i].data()); + bq_dev_buf[i]->ToDevice(bq_tensors[i].data()); + c_m_n_dev_buf[i]->SetZero(); + c_m_n_tensors[i].SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); + const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer(); + const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer(); + + gemm_descs.push_back({p_a, + p_b, + p_c, + p_aq, + p_bq, + kbatch, + M, + N, + K, + AQK, + BQK, + stride_As[i], + stride_Bs[i], + stride_Cs[i], + stride_AQs[i], + stride_BQs[i]}); + } + + float ave_time = invoke_abquant_gemm(warmup, repeat, group_count, gemm_descs); + + std::string op_name = "ABQuant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")"; + + std::size_t flop = 0, num_btype = 0; + for(int j = 0; j < group_count; ++j) + { + flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K; + + num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K + + sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N + + sizeof(CDataType) * gemm_descs[j].M * gemm_descs[j].N; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + for(int i = 0; i < group_count; i++) + { + c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data()); + } + + bool pass{true}; + if(validate) + { + for(int i = 0; i < group_count; ++i) + { + ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( + Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + // Reference implementation for ABQuantGrouped + ck_tile::reference_gemm_abquant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + Ks[i], kbatch, max_accumulated_value); + pass &= + ck_tile::check_err(c_m_n_tensors[i], + c_m_n_host_ref, + "Error: Incorrect results! 
in group [" + std::to_string(i) + "]", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "gemm[" << i + << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + } + + if(arg_parser.get_int("json") == 1) + { + dump_grouped_gemm_json_results(arg_parser.get_str("jsonfile"), + op_name, + group_count, + pass, + ave_time, + tflops, + gb_per_sec); + } + + return pass; +} + +template +int run_abquant_grouped_gemm_example_prec_type_with_bquant( + std::string a_layout, std::string b_layout, std::string c_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using Types = GemmTypeConfig; + // Specific type aliases for easy access + using ADataType = typename Types::ADataType; + using BDataType = typename Types::BDataType; + using AccDataType = typename Types::AccDataType; + using CDataType = typename Types::CDataType; + using AQDataType = typename Types::AccDataType; + using BQDataType = typename Types::AccDataType; + using AQuantGroupSize = ck_tile::QuantGroupShape>; + + constexpr auto QuantMode = ck_tile::QuantType::ABQuantGrouped; + + if(a_layout == "R" && b_layout == "C" && c_layout == "R") + { + return run_abquant_grouped_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); + } + else if(a_layout == "R" && b_layout == "R" && c_layout == "R") + { + return run_abquant_grouped_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R" && c_layout == "R") + { + return run_abquant_grouped_gemm_example_with_layouts( + argc, argv, Col{}, Row{}, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} + +template +int run_abquant_grouped_gemm_example_prec_type(std::string a_layout, + std::string b_layout, + std::string c_layout, + std::string bquant_group_size, + int argc, + char* argv[]) +{ + if(bquant_group_size == "1x1x128") + { + using BQuantGroupSize = ck_tile::QuantGroupShape>; + return run_abquant_grouped_gemm_example_prec_type_with_bquant( + a_layout, b_layout, c_layout, argc, argv); + } + else if(bquant_group_size == "1x128x128") + { + using BQuantGroupSize = ck_tile::QuantGroupShape>; + return run_abquant_grouped_gemm_example_prec_type_with_bquant( + a_layout, b_layout, c_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported BQuantGroupSize! 
Use 1x1x128 or 1x128x128."); + } +} + +template +int run_abquant_gemm_example_persistency(std::string a_layout, + std::string b_layout, + std::string c_layout, + bool persistent, + std::string bquant_group_size, + int argc, + char* argv[]) +{ + if(persistent) + { + using GemmConfig = typename GemmQuantConfig< + ck_tile::QuantType::ABQuantGrouped>::template GemmConfig; + return run_abquant_grouped_gemm_example_prec_type( + a_layout, b_layout, c_layout, bquant_group_size, argc, argv); + } + else + { + using GemmConfig = typename GemmQuantConfig< + ck_tile::QuantType::ABQuantGrouped>::template GemmConfig; + return run_abquant_grouped_gemm_example_prec_type( + a_layout, b_layout, c_layout, bquant_group_size, argc, argv); + } +} + +int run_abquant_grouped_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string c_layout = arg_parser.get_str("c_layout"); + const std::string data_type = arg_parser.get_str("prec"); + bool persistent = arg_parser.get_bool("persistent"); + const std::string bquant_group_size = arg_parser.get_str("bquant_group_size"); + + if(data_type == "fp8") + { + return run_abquant_gemm_example_persistency( + a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv); + } + else if(data_type == "bf8") + { + return run_abquant_gemm_example_persistency( + a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type configuration."); + } +} diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 06a80c8b55..c9e725f5fd 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -484,6 +484,17 @@ struct QuantGroupedGemmKernel tail_num, smem_ptr); } + else if constexpr(kQuantType == QuantType::ABQuantGrouped) + { + return GemmPipeline{}.template operator()(a_block_window, + b_block_window, + aq_block_window, + bq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr); + } else if constexpr(kQuantType == QuantType::RowColQuant || kQuantType == QuantType::TensorQuant) { @@ -499,7 +510,8 @@ struct QuantGroupedGemmKernel c_ptr, kargs, block_idx_m, block_idx_n); if constexpr(kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::BQuantGrouped) + kQuantType == QuantType::BQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) { EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } @@ -527,7 +539,8 @@ struct QuantGroupedGemmKernel c_ptr, kargs, block_idx_m, block_idx_n); if constexpr(kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::BQuantGrouped) + kQuantType == QuantType::BQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) { EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 197c9d6e1d..93cd7fa063 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -9,6 +9,7 @@ add_subdirectory(grouped_gemm) add_subdirectory(grouped_gemm_preshuffle) add_subdirectory(grouped_gemm_multi_d) add_subdirectory(grouped_gemm_quant) +add_subdirectory(grouped_gemm_abquant) add_subdirectory(gemm_multi_d) 
add_subdirectory(gemm_multi_abd) add_subdirectory(gemm_streamk) diff --git a/test/ck_tile/grouped_gemm_abquant/CMakeLists.txt b/test/ck_tile/grouped_gemm_abquant/CMakeLists.txt new file mode 100644 index 0000000000..e735aa8e9a --- /dev/null +++ b/test/ck_tile/grouped_gemm_abquant/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +if(GPU_TARGETS MATCHES "gfx94|gfx95") + add_gtest_executable(test_ck_tile_grouped_gemm_abquant_1x1x128 test_grouped_gemm_abquant_1x1x128.cpp) + target_compile_options(test_ck_tile_grouped_gemm_abquant_1x1x128 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_ck_tile_grouped_gemm_abquant_1x128x128 test_grouped_gemm_abquant_1x128x128.cpp) + target_compile_options(test_ck_tile_grouped_gemm_abquant_1x128x128 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() + diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x128x128.cpp b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x128x128.cpp new file mode 100644 index 0000000000..06b0068cb7 --- /dev/null +++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x128x128.cpp @@ -0,0 +1,47 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_abquant_util.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; + +// AQuant group size is fixed at 1x1x128 +using AQuantGroupSize = ck_tile::QuantGroupShape>; +// BQuant group size: 1x128x128 +using BQuantGroupSize_1x128x128 = ck_tile::QuantGroupShape>; + +// clang-format off +using KernelTypes_ABQuant_1x128x128 = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, AQuantGroupSize, BQuantGroupSize, Persistent + + // FP8 variants + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>, + + // BF8 variants + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmABQuant_1x128x128, KernelTypes_ABQuant_1x128x128); + +#define TEST_CLASS_NAME TestCkTileGroupedGemmABQuant_1x128x128 +#include "test_grouped_gemm_abquant_ut_cases.inc" +#undef TEST_CLASS_NAME diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x1x128.cpp 
b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x1x128.cpp
new file mode 100644
index 0000000000..7704eda169
--- /dev/null
+++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x1x128.cpp
@@ -0,0 +1,47 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include <tuple>
+
+#include "gtest/gtest.h"
+
+#include "ck_tile/host.hpp"
+#include "test_grouped_gemm_abquant_util.hpp"
+
+using F16 = ck_tile::half_t;
+using F32 = float;
+using FP8 = ck_tile::fp8_t;
+using BF8 = ck_tile::bf8_t;
+using Row = ck_tile::tensor_layout::gemm::RowMajor;
+using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
+using True  = ck_tile::bool_constant<true>;
+using False = ck_tile::bool_constant<false>;
+
+// AQuant group size is fixed at 1x1x128
+using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
+// BQuant group size: 1x1x128
+using BQuantGroupSize_1x1x128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
+
+// clang-format off
+using KernelTypes_ABQuant_1x1x128 = ::testing::Types<
+    //          ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, AQuantGroupSize, BQuantGroupSize, Persistent
+
+    // FP8 variants
+    std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>,
+    std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>,
+    std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>,
+    std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>,
+    std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>,
+    std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>,
+
+    // BF8 variants
+    std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>,
+    std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>
+    >;
+// clang-format on
+
+TYPED_TEST_SUITE(TestCkTileGroupedGemmABQuant_1x1x128, KernelTypes_ABQuant_1x1x128);
+
+#define TEST_CLASS_NAME TestCkTileGroupedGemmABQuant_1x1x128
+#include "test_grouped_gemm_abquant_ut_cases.inc"
+#undef TEST_CLASS_NAME
diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_ut_cases.inc b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_ut_cases.inc
new file mode 100644
index 0000000000..48574ab977
--- /dev/null
+++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_ut_cases.inc
@@ -0,0 +1,87 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+TYPED_TEST(TEST_CLASS_NAME, Basic)
+{
+    const int group_count = 6;
+    std::vector<int> Ms;
+    std::vector<int> Ns;
+    std::vector<int> Ks;
+    std::vector<int> stride_As;
+    std::vector<int> stride_Bs;
+    std::vector<int> stride_Cs;
+    std::vector<int> stride_AQs;
+    std::vector<int> stride_BQs;
+    for(int i = 0; i < group_count; i++)
+    {
+        Ms.push_back(256 + 256 * i);
+        Ns.push_back(256 + 512 * i);
+        Ks.push_back(512 + 128 * i);
+
+        stride_As.push_back(0);
+        stride_Bs.push_back(0);
+        stride_Cs.push_back(0);
+        stride_AQs.push_back(0);
+        stride_BQs.push_back(0);
+    }
+
+    this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
+}
+
+// No-hot-loop test case: verifies the correctness of the kernel when there is no hot loop.
+// Using 256x256x128 to match the test kernel's tile size (M_Tile=128, N_Tile=128, K_Tile=128)
+TYPED_TEST(TEST_CLASS_NAME, SmallUniform)
+{
+    const int group_count = 2;
+    std::vector<int> Ms;
+    std::vector<int> Ns;
+    std::vector<int> Ks;
+    std::vector<int> stride_As;
+    std::vector<int> stride_Bs;
+    std::vector<int> stride_Cs;
+    std::vector<int> stride_AQs;
+    std::vector<int> stride_BQs;
+    for(int i = 0; i < group_count; i++)
+    {
+        Ms.push_back(256);
+        Ns.push_back(256);
+        Ks.push_back(256);
+
+        stride_As.push_back(0);
+        stride_Bs.push_back(0);
+        stride_Cs.push_back(0);
+        stride_AQs.push_back(0);
+        stride_BQs.push_back(0);
+    }
+
+    this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
+}
+
+TYPED_TEST(TEST_CLASS_NAME, OddTail)
+{
+    const int group_count = 2;
+    std::vector<int> Ms;
+    std::vector<int> Ns;
+    std::vector<int> Ks;
+    std::vector<int> stride_As;
+    std::vector<int> stride_Bs;
+    std::vector<int> stride_Cs;
+    std::vector<int> stride_AQs;
+    std::vector<int> stride_BQs;
+    for(int i = 0; i < group_count; i++)
+    {
+        Ms.push_back(256);
+        Ns.push_back(256);
+        Ks.push_back(128);
+
+        stride_As.push_back(0);
+        stride_Bs.push_back(0);
+        stride_Cs.push_back(0);
+        stride_AQs.push_back(0);
+        stride_BQs.push_back(0);
+    }
+
+    this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
+}
diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp
new file mode 100644
index 0000000000..c7ed6f5472
--- /dev/null
+++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp
@@ -0,0 +1,530 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT +#pragma once +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm_quant.hpp" + +template +class TestCkTileGroupedGemmABQuant : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using AQDataType = std::tuple_element_t<4, Tuple>; + using BDataType = std::tuple_element_t<5, Tuple>; + using BQDataType = std::tuple_element_t<6, Tuple>; + using AccDataType = std::tuple_element_t<7, Tuple>; + using CDataType = std::tuple_element_t<8, Tuple>; + using AQuantGroupSize = std::tuple_element_t<9, Tuple>; + using BQuantGroupSize = std::tuple_element_t<10, Tuple>; + static constexpr bool Persistent = std::tuple_element_t<11, Tuple>::value; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using AQLayout = Row; + using BQLayout = Col; + + static constexpr auto QuantMode = ck_tile::QuantType::ABQuantGrouped; + + struct GemmConfig + { + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(ADataType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); + + static constexpr bool PreshuffleB = false; + static constexpr bool TransposeC = false; + static constexpr bool DoubleSmemBuffer = false; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + + static constexpr bool IsPersistent = Persistent; + }; + + using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; + + std::size_t get_workspace_size(const std::vector& gemm_descs) + { + return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg); + } + + template + static constexpr inline auto is_row_major(Layout layout_) + { + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; + } + + auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) + { + using ComputeType = + std::conditional_t; + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); + } + + template + float invoke_grouped_gemm_abquant(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) + { + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t 
TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile:: + TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * Config::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * Config::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = Config::Scheduler; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Config::M_Warp, + Config::N_Warp, + Config::M_Warp_Tile, + Config::N_Warp_Tile, + Config::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } + + template + void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr) + { + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Config::M_Warp, + Config::N_Warp, + 
Config::M_Warp_Tile, + Config::N_Warp_Tile, + Config::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + } + + public: + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + std::vector& stride_As, + std::vector& stride_Bs, + std::vector& stride_Cs, + std::vector& stride_AQs, + std::vector& stride_BQs, + const int group_count = 8) + { + ck_tile::index_t AQK, BQK; + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> c_m_n_tensors; + std::vector> aq_tensors; + std::vector> bq_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + c_m_n_tensors.reserve(group_count); + aq_tensors.reserve(group_count); + bq_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> c_m_n_dev_buf; + std::vector> aq_dev_buf; + std::vector> bq_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + c_m_n_dev_buf.reserve(group_count); + aq_dev_buf.reserve(group_count); + bq_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + AQK = K / AQuantGroupSize::kK; + BQK = K / BQuantGroupSize::kK; + + if(K % AQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by AQuantGroupSize::kK for ABQuantGrouped mode"); + } + if(K % BQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by BQuantGroupSize::kK for ABQuantGrouped mode"); + } + + stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(ALayout{})); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(BLayout{})); + stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{})); + stride_AQs[i] = + ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout{})); + stride_BQs[i] = + ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(BQLayout{})); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(ALayout{})))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(BLayout{})))); + c_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(BQLayout{})))); + + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc + << " c_m_n: " << c_m_n_tensors[i].mDesc << " aq: " << aq_tensors[i].mDesc + << " bq: " << 
bq_tensors[i].mDesc << std::endl; + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(bq_tensors[i]); + + a_m_k_dev_buf.push_back(std::make_unique( + a_m_k_tensors[i].get_element_space_size_in_bytes())); + b_k_n_dev_buf.push_back(std::make_unique( + b_k_n_tensors[i].get_element_space_size_in_bytes())); + c_m_n_dev_buf.push_back(std::make_unique( + c_m_n_tensors[i].get_element_space_size_in_bytes())); + aq_dev_buf.push_back(std::make_unique( + aq_tensors[i].get_element_space_size_in_bytes())); + bq_dev_buf.push_back(std::make_unique( + bq_tensors[i].get_element_space_size_in_bytes())); + + a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + aq_dev_buf[i]->ToDevice(aq_tensors[i].data()); + bq_dev_buf[i]->ToDevice(bq_tensors[i].data()); + c_m_n_dev_buf[i]->SetZero(); + c_m_n_tensors[i].SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); + const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer(); + const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer(); + + gemm_descs.push_back({p_a, + p_b, + p_c, + p_aq, + p_bq, + 1, // k_batch + M, + N, + K, + AQK, + BQK, + stride_As[i], + stride_Bs[i], + stride_Cs[i], + stride_AQs[i], + stride_BQs[i]}); + } + + ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(gemm_descs)); + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + + if constexpr(Persistent) + { + std::vector kargs; + for(const auto& arg : gemm_descs) + { + kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.aq_ptr, + arg.bq_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.QK_A, + arg.QK_B, + arg.stride_A, + arg.stride_B, + arg.stride_E, + arg.stride_AQ, + arg.stride_BQ, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, false, 1}; + ck_tile::hip_check_error( + hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr); + } + else + { + const auto stream = ck_tile::stream_config{nullptr, false, 1}; + invoke_grouped_gemm_abquant(gemm_descs, stream, kargs_ptr); + } + + // Copy results back to host for validation + for(int i = 0; i < group_count; i++) + { + c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data()); + } + + bool pass{true}; + for(int i = 0; i < group_count; ++i) + { + ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( + Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_abquant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(Ks[i], 1, max_accumulated_value); + pass &= + ck_tile::check_err(c_m_n_tensors[i], + c_m_n_host_ref, + "Error: Incorrect results! 
in group [" + std::to_string(i) + "]",
+                                   rtol_atol.at(ck_tile::number<0>{}),
+                                   rtol_atol.at(ck_tile::number<1>{}));
+            std::cout << "gemm[" << i
+                      << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
+                      << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
+                      << std::endl;
+        }
+        std::cout << "The CPU verification result is: " << (pass ? "correct" : "fail")
+                  << std::endl;
+
+        EXPECT_TRUE(pass);
+    }
+};
+
+// Aliases for split test files
+template <typename Tuple>
+using TestCkTileGroupedGemmABQuant_1x1x128 = TestCkTileGroupedGemmABQuant<Tuple>;
+
+template <typename Tuple>
+using TestCkTileGroupedGemmABQuant_1x128x128 = TestCkTileGroupedGemmABQuant<Tuple>;
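
Usage note (not part of the patch): once built, the new example binary accepts the flags defined in create_args() above. A minimal invocation sketch, assuming the -name=value syntax of ck_tile::ArgParser used by the other 17_grouped_gemm examples and that the binaries land in the build's bin/ directory (both assumptions):

    ./bin/tile_example_abquant_grouped_gemm -prec=fp8 -group_count=8 -persistent=0 -bquant_group_size=1x1x128 -validate=1
    ./bin/tile_example_abquant_grouped_gemm -prec=bf8 -persistent=1 -bquant_group_size=1x128x128
    ./bin/test_ck_tile_grouped_gemm_abquant_1x1x128

The AQuant group size is fixed at 1x1x128; -bquant_group_size selects the BQuant granularity (1x1x128 by default, or 1x128x128), and -persistent=1 switches from the non-persistent kernel to the persistent tile-loop variant.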