diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 7a9b5afaa2..24f64994cf 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -114,16 +114,16 @@ template struct GemmConfigComputeV3 : public GemmConfigBase { // Compute V3 only support Intrawave scheduler - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt new file mode 100644 index 0000000000..bdcb6f50bd --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -0,0 +1,13 @@ +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) + +if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") + add_executable(tile_example_gemm_aquant_basic EXCLUDE_FROM_ALL gemm_aquant_basic.cpp) + target_compile_options(tile_example_gemm_aquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +else() + 
message(DEBUG "Skipping ck_tile quant gemm tests for current target") +endif() diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md new file mode 100644 index 0000000000..742a88dee7 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -0,0 +1,35 @@ +# GEMM Matrix Multiplication + +This folder contains example for Block Scale GEMM using ck_tile tile-programming implementation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# The aquant pipeline method on the gemm calculation +make tile_example_gemm_aquant_basic -j +``` +This will result in an executable `build/bin/tile_example_gemm_aquant_basic` + +## example +``` +args: + -b batch size (default:1) + -m m dimension (default:1024) + -n n dimension (default:2048) + -k k dimension (default:64) + -a_layout Tensor A data layout (default: R) + -b_layout Tensor B data layout (default: R) + -c_layout Tensor C data layout (default: R) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) + -e Absolute error tolerance (default:1e-5) + -prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16) + -warmup number of iterations before benchmark the kernel (default:10) + -repeat number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) +``` diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp new file mode 100644 index 0000000000..a1ed3c4920 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host.hpp" +#include "gemm_utils.hpp" + +template +float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s) +{ + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr int kBlockPerCu = 1; + + static_assert(std::is_same_v); + + constexpr ck_tile::index_t M_Tile = 16; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 256; + + constexpr ck_tile::index_t M_Warp = 1; + constexpr ck_tile::index_t N_Warp = 4; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 16; + constexpr ck_tile::index_t N_Warp_Tile = 16; + constexpr ck_tile::index_t K_Warp_Tile = 32; + + using CodegenGemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + + using CodegenGemmTraits = + ck_tile::TileGemmAQuantTraits; + + using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; + + using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + constexpr bool transposed_warp_gemm = false; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + + using CodegenPipelineProblem = + ck_tile::GemmAQuantPipelineProblem; + using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + 
ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + CodegenPipelineProblem::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + transposed_warp_gemm, + ck_tile::memory_operation_enum::set>>; + using Kernel = + ck_tile::AQuantGemmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(args.k_batch != 1) + { + throw std::runtime_error("split-k is not supported yet!"); + } + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + +#include "run_gemm_aquant_example.inc" + +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } + } + else + { + throw 
std::runtime_error("Unsupported data type for A."); + } + + return 0; +} + +int run_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(data_type == "fp8") + { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "bf8") + { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4fp8") + { + using TypeConfig = decltype( + GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4bf8") + { + using TypeConfig = decltype( + GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4f32fp8") + { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4f32bf8") + { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } +} + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp new file mode 100644 index 0000000000..35e80ddb89 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -0,0 +1,675 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm_group_quant.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_COMPUTE_V5 4 +#define CK_TILE_PIPELINE_PRESHUFFLE 5 + +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if defined(__gfx950__) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +} +template +constexpr ck_tile::index_t get_k_warp_tile_flatmm() +{ +#if defined(__gfx950__) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 
32 : 64; +#endif +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool Preshuffle = false; +}; + +template +struct GemmConfigMemoryInterwave : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr 
ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template +struct GemmConfigMemoryIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; +}; + +template +struct GemmConfigComputeV3 : public GemmConfigBase +{ + // Compute V3 only support Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 32; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr 
ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; +}; + +template +struct GemmConfigComputeV4 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = 
CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV4_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV5 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 2; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; +}; + +template +struct GemmConfigPreshufle_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr 
ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + +template +struct GemmConfigPreshufle_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + +template +struct GemmTypeConfig; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. 
+}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::int8_t; + using BDataType = ck_tile::int8_t; + using AccDataType = int32_t; + using CDataType = int32_t; +}; + +template +struct GemmQuantTypeConfig +{ + using ADataType = ADataType_; + using QDataType = QDataType_; + using BDataType = BDataType_; + using AccDataType = float; + using CDataType = CDataType_; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::half_t; + using QDataType = float; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using QDataType = float; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = float; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = float; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + 
+template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::half_t; + using QDataType = float; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = float; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = float; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = float; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = float; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using 
QDataType = ck_tile::fp8_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = ck_tile::bf8_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = float; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = float; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using 
UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; + template + using UniversalGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("aq_layout", "R", "Aq tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Column by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_q", "0", "Tensor AQ stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "i4fp8", "data type. 
fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("persistent", "0", "0:non-persistent, 1:persistent") + .insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// host API +float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc new file mode 100644 index 0000000000..9bdef9755b --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_aquant_example.inc @@ -0,0 +1,259 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once +#include +#include + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& aq_m_aqk_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t AQK, + ck_tile::index_t stride_A, + ck_tile::index_t stride_AQ, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) +{ + ck_tile::AQuantGemmHostArgs args; + args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer(); + args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.k_batch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.QK = AQK; + args.stride_A = stride_A; + args.stride_B = stride_B; + args.stride_C = stride_C; + args.stride_AQ = stride_AQ; + + float ave_time = gemm_calc_aquant( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(AQDataType) * M * AQK + + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B + << " StrideC =" << stride_C << " A_Layout =" << ALayout::name + << " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name + << " A_Type = " << DataTypeTraits::name + << " AQ_Type = " << DataTypeTraits::name + << " B_Type = " << DataTypeTraits::name + << " Acc_Type = " << DataTypeTraits::name + << " C_Type = " << DataTypeTraits::name << " : " 
<< ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + return ave_time; +} + +template +int run_gemm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const AQLayout aq_layout = AQLayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using ADataType = typename TypeConfig::ADataType; + using AQDataType = typename TypeConfig::QDataType; + using BDataType = typename TypeConfig::BDataType; + using AccDataType = typename TypeConfig::AccDataType; + using CDataType = typename TypeConfig::CDataType; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + if(K % QuantGroupSize != 0) + { + throw std::runtime_error("K must be aligned with QuantGroupSize"); + } + + ck_tile::index_t AQK = K / QuantGroupSize; + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor aq_m_aqk( + ck_tile::host_tensor_descriptor(M, AQK, 
stride_AQ, is_row_major(aq_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution fill_seed(0, 500); + + if(init_method == 0) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(aq_m_aqk); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + } + else if(init_method == 1) + { + std::cout << "Monotonic initialization is not supported." << std::endl; + return 0; + } + else if(init_method == 2) + { + ck_tile::FillConstant{static_cast(0x22)}(a_m_k); + ck_tile::FillConstant{static_cast(0.5f)}(aq_m_aqk); + ck_tile::FillConstant{static_cast(0x38)}(b_k_n); + } + else + { + a_m_k.SetZero(); + aq_m_aqk.SetZero(); + b_k_n.SetZero(); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + invoke_gemm(a_m_k_dev_buf, + aq_m_aqk_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + AQK, + stride_A, + stride_AQ, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + 
 ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_quant(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + if(!pass) + { + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl; + return false; + } + + return pass; +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 8989060842..db5cc71888 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -23,3 +23,4 @@ add_subdirectory(20_grouped_convolution) add_subdirectory(35_batched_transpose) add_subdirectory(36_copy) add_subdirectory(37_transpose) +add_subdirectory(38_block_scale_gemm) diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index 541093e337..ba8b87a9b8 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -116,6 +116,24 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x) return res; } +CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_int4_t& x) +{ + uint8_t x_u8 = ck_tile::bit_cast(x); + + float x_l = ((x_u8 & 0x0f) >> 0); + float x_h = ((x_u8 & 0xf0) >> 4); + + x_l = x_l > 7 ? x_l - 16 : x_l; + x_h = x_h > 7 ? 
x_h - 16 : x_h; + +#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE + fp32x2_t res = {x_h, x_l}; +#else + fp32x2_t res = {x_l, x_h}; +#endif + return res; +} + CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x) { uint8_t x_u8 = ck_tile::bit_cast(x); diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 4a359e031f..9b31a7889d 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -92,6 +93,60 @@ struct FillUniformDistribution } }; +template <> +struct FillUniformDistribution +{ + float a_{-8.f}; // same type as primary template so that + // `FillUniformDistribution{-5.0f, 5.0f}` works for all types + float b_{7.f}; + std::optional seed_{11939}; + template + void operator()(ForwardIter first, ForwardIter last) const + { + if(a_ < -8.0f || b_ > 7.0f) + { + throw std::runtime_error( + "a_ or b_ of FillUniformDistribution is out of range."); + } + + int min_value = static_cast(a_); + int max_value = static_cast(b_); + constexpr auto int4_array = std::array{0x88, + 0x99, + 0xaa, + 0xbb, + 0xcc, + 0xdd, + 0xee, + 0xff, + 0x00, + 0x11, + 0x22, + 0x33, + 0x44, + 0x55, + 0x66, + 0x77}; + std::mt19937 gen(seed_.has_value() ? 
 *seed_ : std::random_device{}()); + std::uniform_int_distribution dis(0, max_value - min_value); + while(first != last) + { + int randomInt = dis(gen); + *first = int4_array[randomInt + (min_value + 8)]; + ++first; + } + } + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + namespace impl { // clang-format off diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index c88deaec01..70ca44170e 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -11,6 +11,110 @@ namespace ck_tile { +template +CK_TILE_HOST void reference_gemm_quant(const HostTensor& a_m_k, + const HostTensor& q, + const HostTensor& b_k_n, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) +{ + const std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); + + auto f_mn = [&](auto m, auto n) { + AccDataType v_acc = 0, v_block_acc = 0; + + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v); + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v); + static_assert(std::is_same_v); + static_assert(std::is_same_v || + std::is_same_v); + for(std::size_t k = 0; k < K; ++k) + { + AccDataType v_a; + AccDataType v_b; + if constexpr(std::is_same_v) + { + const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val); + if(k % 2 == 1) + v_a = fp32_val.hi; + else + v_a = fp32_val.lo; + } + else + { + v_a = ck_tile::type_convert(a_element_op(a_m_k(m, k))); + } + if constexpr(std::is_same_v) + { + const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); + 
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val); + if(k % 2 == 1) + v_b = fp32_val.hi; + else + v_b = fp32_val.lo; + } + else if constexpr(std::is_same_v) + { + v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n))); + } + else + { + v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); + } + v_block_acc += v_a * v_b; + + // Apply group dequant scale + if((k + 1) % QuantGroupSize == 0) + { + float scale = 0.f; + index_t outer_dim = (aquant) ? m : k / QuantGroupSize; + index_t inner_dim = (aquant) ? k / QuantGroupSize : n; + + if constexpr(std::is_same_v) + { + scale = q(outer_dim, inner_dim); + } + else if constexpr(std::is_same_v) + { + scale = fp8_to_float_raw(q(outer_dim, inner_dim)); + } + else if constexpr(std::is_same_v) + { + scale = bf8_to_float_raw(q(outer_dim, inner_dim)); + } + else + { + static_assert(false, "Unexpected Q datatype."); + } + v_block_acc *= scale; + v_acc += v_block_acc; + v_block_acc = 0; + } + } + + c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); + }; + + make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); + std::cout << std::endl; +} + template (a), src_hi; + uint32_t fp8x4_lo, fp8x4_hi; + float tmp_0, tmp_1; + + asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n" + "v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n" + "v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n" + "v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n" + "v_cvt_pk_fp8_f32 %[v_dst_lo], 
%[v_tmp_1], %[v_tmp_0]\n" + : [v_tmp_0] "+v"(tmp_0), + [v_tmp_1] "+v"(tmp_1), + [v_hi_src] "+v"(src_hi), + [v_dst_lo] "+v"(fp8x4_lo), + [v_dst_hi] "+v"(fp8x4_hi), + [v_src] "+v"(src) + :); + + return bit_cast(((static_cast(fp8x4_hi) << 32) | fp8x4_lo)); +} + +CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src) +{ + float res; + asm volatile("v_cvt_f32_fp8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src)); + return res; +} + +CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src) +{ + float res; + asm volatile("v_cvt_f32_bf8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src)); + return res; +} + +CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(int a) +{ + uint32_t src = static_cast(a), src_hi; + uint32_t bf8x4_lo, bf8x4_hi; + float tmp_0, tmp_1; + + asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n" + "v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n" + "v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n" + "v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n" + + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n" + "v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n" + : [v_tmp_0] "+v"(tmp_0), + [v_tmp_1] "+v"(tmp_1), + [v_hi_src] "+v"(src_hi), + [v_dst_lo] "+v"(bf8x4_lo), + [v_dst_hi] "+v"(bf8x4_hi), + [v_src] "+v"(src) + :); + + return bit_cast(((static_cast(bf8x4_hi) << 32) | bf8x4_lo)); +} + struct PassThroughPack8 { template @@ -126,6 +206,16 @@ struct PassThroughPack8 y.lo = i4_to_bhalf4(bit_cast(x)); y.hi = i4_to_bhalf4(bit_cast(x) >> 16); } + + CK_TILE_HOST_DEVICE constexpr void 
operator()(fp8x8_t& y, const pk_int4x4_t& x) const + { + y = amd_assembly_i4_to_fp8x8(bit_cast(x)); + } + + CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const + { + y = amd_assembly_i4_to_bf8x8(bit_cast(x)); + } constexpr const static bool is_pack8_invocable = true; }; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index bf58544259..7ae63e17a7 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -69,6 +69,8 @@ struct CShuffleEpilogue using ODataType = remove_cvref_t; using DsDataType = remove_cvref_t; using DsLayout = remove_cvref_t; + using ATypeToUse = + std::conditional_t, BDataType, ADataType>; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, ADataType, BDataType>; @@ -201,7 +203,7 @@ struct CShuffleEpilogue static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle); static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle); - using WG = WarpGemmMfmaDispatcher [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); return concat('_', "pipeline_AgBgCrCompV3", - concat('x', MPerBlock, NPerBlock, KPerBlock, BlockSize), - concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, + concat('x', WaveNumM, WaveNumN), concat('x', kPadM, kPadN, kPadK)); // clang-format on } diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 27a81ff090..97fab489ab 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ 
b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -37,6 +37,7 @@ struct WarpGemmAtrributeMfma static constexpr index_t kN = Impl::kN; static constexpr index_t kK = Impl::kK; static constexpr index_t kKPerThread = Impl::kABKPerLane; + static constexpr index_t kCMLane = Impl::kCMLane; CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp index f9d50ed35e..38fd0d408b 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp @@ -11,9 +11,10 @@ struct WarpGemmImpl { using WarpGemmAttribute = remove_cvref_t; - static constexpr index_t kM = WarpGemmAttribute::kM; - static constexpr index_t kN = WarpGemmAttribute::kN; - static constexpr index_t kK = WarpGemmAttribute::kK; + static constexpr index_t kM = WarpGemmAttribute::kM; + static constexpr index_t kN = WarpGemmAttribute::kN; + static constexpr index_t kK = WarpGemmAttribute::kK; + static constexpr index_t kCMLane = WarpGemmAttribute::kCMLane; /// @brief The number of elements in K dimension processed by single thread in wavefront. /// /// @note Note that WarpGemm may run MFMA instruction multiple times (on different K). diff --git a/include/ck_tile/ops/gemm_group_quant.hpp b/include/ck_tile/ops/gemm_group_quant.hpp new file mode 100644 index 0000000000..0041c658b4 --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant.hpp @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" +#include "ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp" diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp new file mode 100644 index 0000000000..c1ff6a356e --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -0,0 +1,489 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/elementwise.hpp" + +namespace ck_tile { + +template +struct BlockGemmQuantBase +{ + using AQDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + + static constexpr index_t UnaryOpSize = UnaryOpSize_; + template + CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale) + { + float scale_reg_f = 0.f; + if constexpr(std::is_same_v) + { + scale_reg_f = + ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast(scale)); + } + else if constexpr(std::is_same_v) + { + scale_reg_f = + ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast(scale)); + } + else if constexpr(std::is_same_v) + { + scale_reg_f = ck_tile::bit_cast(scale); + } + else + { + static_assert(false, "AQDataType must be float, fp8_t or bf8_t."); + } + return scale_reg_f; + } + + template + CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, + const WarpWindow& warp_window) + { + const element_wise::PassThroughPack8 elementwise_op{}; + + static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); + constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; + const auto in_dstr_tensors = load_tile(warp_window); + + using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), + in_dstr_tensors.get_thread_buffer().template get_as()[i]); + }); + } +}; + +// A is block window on shared memory +// AQ (scale tensor) is block distributed tensor. +// Consecutive kQuantGroupSize elements of A are quantized with a separate scale. 
+// B is block window on shared memory +// C is block distributed tensor +template +struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase +{ + private: + template + struct GemmTraits_ + { + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr auto Scheduler = Problem::Scheduler; + + // Threadblock GEMM tile size + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr index_t AQPerBlock = KPerBlock / kQuantGroupSize; + + static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + // number of warps along M and N for threadblock's GEMM problem size + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + using I0 = number<0>; + using I1 = number<1>; + + static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}), + "Error! WarpGemm's MWarp is not consisten with BlockGemmShape!"); + static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}), + "Error! WarpGemm's NWarp is not consisten with BlockGemmShape!"); + static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}), + "Error! WarpGemm's M is not consisten with BlockGemmShape!"); + static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}), + "Error! 
WarpGemm's N is not consisten with BlockGemmShape!"); + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + static constexpr index_t QScalesPerBlockRow = + (KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize; + static constexpr index_t QScalesPerWarpGemmRow = + (WarpGemm::kK + kQuantGroupSize - 1) / kQuantGroupSize; + + static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; + + static_assert(kQuantGroupSize % WarpGemm::kK == 0, + "Error! WarpGemm::kK should be a multiple of kQuantGroupSize"); + static_assert(QScalesPerWarpGemmRow == 1, + "Error! kQuantGroupSize shouldn't be smaller than WarpGemm::kK"); + static_assert(KIterPerWarp % QScalesPerBlockRow == 0, + "Error! KItersPerWarp should be a multiple of QscalesPerBlockRow"); + + static_assert(KPerBlock / kQuantGroupSize > 0, + "Error! Each row of blockgemm should have a separate scale"); + + static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock, + "Error! Warps should cover all Block tile!"); + static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock, + "Error! Warps should cover all Block tile!"); + + // Currently tested combinations (A, AQ, B) + // 1. fp8, fp32, fp8 -> f32 + // 2. bf8, fp32, bf8 -> f32 + // 3. i4, (fp8/fp32) fp8 -> f32 + // 4. 
i4, (fp8/fp32) bf8 -> f32 + static_assert( + (std::is_same_v || std::is_same_v || + std::is_same_v< + ADataType, + bf8_t>)&&(std::is_same_v || + std::is_same_v< + BDataType, + bf8_t>)&&(std::is_same_v || + std::is_same_v || + std::is_same_v< + AQDataType, + ck_tile::bf8_t>)&&(std::is_same_v || + std::is_same_v)&&std:: + is_same_v); + + static constexpr index_t InterWaveSchedulingMacClusters = 1; + + static constexpr index_t KPack = WarpGemm::kKPerThread; + static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; + }; + + public: + using Traits = GemmTraits_; + + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using Base = BlockGemmQuantBase; + + using WarpGemm = remove_cvref_t; + + static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; + static constexpr index_t MIterPerWarp = Traits::MIterPerWarp; + static constexpr index_t NIterPerWarp = Traits::NIterPerWarp; + + static constexpr index_t MWarp = Traits::MWarp; + static constexpr index_t NWarp = Traits::NWarp; + + static constexpr auto Scheduler = Traits::Scheduler; + static constexpr uint8_t kA_cvt_scale = std::is_same_v ? 16 : 1; + static constexpr uint8_t kB_cvt_scale = std::is_same_v ? 
16 : 1; + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + static_assert(std::is_same_v); + + static constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + using I0 = number<0>; + using I1 = number<1>; + + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() + { + constexpr index_t KPerThread = Traits::KPerThread; + constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; + + constexpr index_t KPerInnerLoop = + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); + + constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; + + using KIterSeq = std::conditional_t, + sequence>; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + 
constexpr index_t KPerThread = Traits::KPerThread; + constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; + constexpr index_t KPerInnerLoop = + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); + constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; + + using KIterSeq = std::conditional_t, + sequence>; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + + private: + template + struct BlockGemmImpl + { + }; + + template + struct BlockGemmImpl + { + static constexpr auto ALdsTileDistr = + decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_warp_tile_; + BLdsTile b_warp_tile_; + + template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + if constexpr(std::is_same_v) + { + static_assert(std::is_same_v || + std::is_same_v); + Base::load_interleaved_pk_type(a_warp_tile_, a_block_window); + } + else + { + load_tile(a_warp_tile_, a_block_window); + } + if constexpr(std::is_same_v) + { + static_assert(std::is_same_v || + std::is_same_v); + Base::load_interleaved_pk_type(b_warp_tile_, b_block_window); + } + else + { + load_tile(b_warp_tile_, b_block_window); + } + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + AQBlockTensor& aq_block_tensor, + [[maybe_unused]] ASmemBlockWindow& a_block_window, + 
[[maybe_unused]] BSmemBlockWindow& b_block_window) + { + static_assert(std::is_same_v, + "The CDataType as defined in traits should be the same as correspoinding " + "C block tensor data type!"); + + // hot loop: + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + CWarpTensor c_warp_tensor; + + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; + + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = + a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = + b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + if constexpr(kIterInQScale == 0) + { + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + } + else + { + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } + }); + + // Need to multiply aquant with accumulated C + // + // The accumulated C tile has the standard distribution. For example + // lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0], + // [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0], + // [26,0], [27,0]. + // + // These elements are in different rows, need to get the scale value + // for the corresponding row. + // Based on aquant's tile distribution, it can be inferred which + // lane holds the relevant scale. For example, the scales corresponding + // to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9, + // 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively. + // + // These scales can be obtained using __builtin_amdgcn_ds_bpermute. 
+ + // MIters per warp + constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM; + + // Reg block offset based on mIter + constexpr index_t reg_block_offset = + ((mIter / mIters_per_warp) * Traits::AQPerBlock); + + constexpr index_t lane_base_offset = + (mIter % mIters_per_warp) * WarpGemm::kM; + + // Scale tensor offset along K + constexpr index_t src_reg_offset = reg_block_offset + kQScale; + + constexpr uint32_t kTileRows = 4; + constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows; + + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + static_for<0, WarpGemm::kM, WarpGemm::kCMLane>{}([&](auto c_row) { + // Multiply by 4 because output is stored in tiles of 4 + // x CNLane + constexpr uint32_t row_base = + ((c_row / kTiledCMsPerWarp) * kTiledCMsPerWarp) + + ((c_row % kTiledCMsPerWarp) / WarpGemm::kCMLane); + + constexpr uint32_t reg_offset_for_row_data = c_row / WarpGemm::kCMLane; + + // Lane index to source scale from + uint32_t src_lane_idx = lane_base_offset + row_base + + (__lane_id() / WarpGemm::kN * kTileRows); + + // Directly index into thread buffer corresponding to + // desired row coefficient + auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset]; + uint32_t scale_reg_dword; + + if constexpr(std::is_same_v) + { + scale_reg_dword = ck_tile::bit_cast(scale_reg); + } + else + { + scale_reg_dword = static_cast(scale_reg); + } + + // Pull scale data across lanes + int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( + src_lane_idx * 4, __builtin_bit_cast(int, scale_reg_dword)); + + float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg); + + c_block_tensor + .get_thread_buffer()[tbuf_offset + reg_offset_for_row_data] += + (c_warp_tensor.get_thread_buffer()[reg_offset_for_row_data] * + scale_reg_f * kA_cvt_scale * kB_cvt_scale); + }); + }); + }); + }); + } + }; + + public: + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr auto 
c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + return c_block_tensor; + } + + template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window); + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + AQBlockTensor& aq_block_tensor, + const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + block_gemm_impl_(c_block_tensor, aq_block_tensor, a_block_window, b_block_window); + } + + private: + BlockGemmImpl block_gemm_impl_{}; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp new file mode 100644 index 0000000000..b1f89fe2e2 --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp @@ -0,0 +1,679 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +struct AQuantGemmProblem +{ + CK_TILE_HOST AQuantGemmProblem() = default; + CK_TILE_HOST AQuantGemmProblem(index_t M_, + index_t N_, + index_t K_, + index_t QK_, + index_t stride_A_, + index_t stride_B_, + index_t stride_C_, + index_t stride_AQ_) + : M(M_), + N(N_), + K(K_), + QK(QK_), + stride_A(stride_A_), + stride_B(stride_B_), + stride_C(stride_C_), + stride_AQ(stride_AQ_) + { + } + + index_t M; + index_t N; + index_t K; + index_t QK; + index_t stride_A; + index_t stride_B; + index_t stride_C; + index_t stride_AQ; +}; + +struct AQuantGemmHostArgs : public AQuantGemmProblem +{ + CK_TILE_HOST AQuantGemmHostArgs() = default; + CK_TILE_HOST AQuantGemmHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* c_ptr_, + const void* aq_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t QK_, + index_t stride_A_, + index_t stride_B_, + index_t stride_C_, + index_t stride_AQ_) + : AQuantGemmProblem(M_, N_, K_, QK_, stride_A_, stride_B_, stride_C_, stride_AQ_), + a_ptr(a_ptr_), + b_ptr(b_ptr_), + aq_ptr(aq_ptr_), + c_ptr(c_ptr_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_ptr; + const void* aq_ptr; + void* c_ptr; + index_t k_batch; +}; + +struct AQuantGemmKernelArgs +{ + const void* a_ptr; + const void* b_ptr; + const void* aq_ptr; + void* c_ptr; + index_t M; + index_t N; + index_t K; + index_t QK; + index_t stride_A; + index_t stride_B; + index_t stride_C; + index_t stride_AQ; + index_t k_batch; +}; + +template +struct AQuantGemmKernel +{ + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using AQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t KernelBlockSize = 
GemmPipeline::BlockSize; + + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto I3 = number<3>(); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm", gemm_prec_str, GemmPipeline::GetName()); + // clang-format on + } + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) + { + return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + CK_TILE_HOST static constexpr AQuantGemmKernelArgs + MakeKernelArgs(const AQuantGemmHostArgs& hostArgs) + { + return AQuantGemmKernelArgs{hostArgs.a_ptr, + hostArgs.b_ptr, + hostArgs.aq_ptr, + hostArgs.c_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.QK, + hostArgs.stride_A, + hostArgs.stride_B, + hostArgs.stride_C, + hostArgs.stride_AQ, + hostArgs.k_batch}; + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(const AQuantGemmKernelArgs& kargs, + const std::size_t k_id = blockIdx.z) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); + const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1); + + if constexpr(std::is_same_v) + { + a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + else if constexpr(std::is_same_v) + { + a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A); + } + + if constexpr(std::is_same_v) + { + b_k_split_offset = 
__builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B); + } + else if constexpr(std::is_same_v) + { + b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + + if(k_id < static_cast(kargs.k_batch - 1)) + { + splitted_k = __builtin_amdgcn_readfirstlane(KRead); + } + else + { + splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t splitted_k; + }; + + CK_TILE_HOST static bool IsSupportedArgument(const AQuantGemmKernelArgs& kargs) + { + if(kargs.k_batch != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + } + return false; + } + + static_assert(std::is_same_v); + if(kargs.QK % GemmPipeline::GetVectorSizeAQ() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + } + return false; + } + + if constexpr(std::is_same_v) + { + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + return false; + } + if(kargs.K % GemmPipeline::GetVectorSizeA() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + } + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } + return false; + } + if(kargs.M % GemmPipeline::GetVectorSizeA() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for A 
tensor!"); + } + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } + return false; + } + if(kargs.N % GemmPipeline::GetVectorSizeB() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); + } + return false; + } + } + else + { + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + return false; + } + if(kargs.K % GemmPipeline::GetVectorSizeB() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!"); + } + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } + return false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!"); + } + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } + return false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + 
{ + CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!"); + } + return false; + } + } + return true; + } + + template + CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, + const BDataType* b_ptr, + const AQDataType* aq_ptr, + CDataType* c_ptr, + const AQuantGemmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset) + { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); + const auto& a_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + a_ptr, + make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + a_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + }(); + + const auto& aq_tensor_view = [&]() { + static_assert(std::is_same_v); + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.M, kargs.QK), + make_tuple(kargs.stride_AQ, 1), + number{}, + number<1>{}); + }(); + + const auto& b_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + return make_naive_tensor_view( + b_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.N), + 
make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + } + else + { + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + } + }(); + + // TODO: enable vector write for C in ColMajor + const auto& c_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_C, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_C), + number<1>{}, + number<1>{}); + } + }(); + + return make_tuple(a_tensor_view, aq_tensor_view, b_tensor_view, c_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + const auto& aq_pad_view = [&]() { + const auto& aq_tensor_view = views.at(I1); + 
static_assert(std::is_same_v); + return pad_tensor_view( + aq_tensor_view, + make_tuple(number{}, + number{}), + // TODO: Add support for padding. + sequence{}); + }(); + + const auto& b_pad_view = [&]() { + const auto& b_tensor_view = views.at(I2); + if constexpr(std::is_same_v) + { + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // TODO vector write in for C in ColMajor + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I3); + if constexpr(std::is_same_v) + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + return make_tuple(a_pad_view, aq_pad_view, b_pad_view, c_pad_view); + } + + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& a_pad_view = views.at(I0); + const auto& aq_pad_view = views.at(I1); + const auto& b_pad_view = views.at(I2); + const auto& c_pad_view = views.at(I3); + + const auto& a_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, i_m}); + } + }(); + + const auto& aq_block_window = [&]() { + static_assert(std::is_same_v); + return make_tile_window( + aq_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + }(); + + const auto& b_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + } + else + { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {0, i_n}); + } + }(); + + auto c_block_window = make_tile_window( 
+ c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return make_tuple(a_block_window, aq_block_window, b_block_window, c_block_window); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param aq_ptr input AQ pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + * @tparam DstInMemOp Destination memory operation (default: set). + */ + template + CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, + const BDataType* b_ptr, + const AQDataType* aq_ptr, + CDataType* c_ptr, + void* smem_ptr_0, + const AQuantGemmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( + a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + + // Run GEMM cooperatively by whole workgroup. 
+ const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& aq_block_window = gemm_tile_windows.at(I1); + const auto& b_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + + CK_TILE_DEVICE void operator()(AQuantGemmKernelArgs kargs) const + { + const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const SplitKBatchOffset splitk_batch_offset(kargs); + // options + const ADataType* a_ptr = static_cast(kargs.a_ptr); + const BDataType* b_ptr = static_cast(kargs.b_ptr); + const AQDataType* aq_ptr = static_cast(kargs.aq_ptr); + CDataType* c_ptr = static_cast(kargs.c_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + assert(kargs.k_batch == 1); + RunGemm(a_ptr, b_ptr, aq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp new file mode 100644 index 0000000000..1356d7e222 --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" + +namespace ck_tile { + +template +struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase +{ + using Base = GemmPipelineAgBgCrImplBase; + using ADataType = typename Base::ADataType; + using ALayout = typename Base::ALayout; + using BDataType = typename Base::BDataType; + using BLayout = typename Base::BLayout; + using BlockGemmShape = typename Base::BlockGemmShape; + + using AQLayout = remove_cvref_t; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize; + static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize; + + static_assert(KPerBlock % QuantGroupSize == 0, + "KPerBlock must be a multiple of QuantGroupSize"); + + // Create DRAM tile window for AQ + template + CK_TILE_DEVICE constexpr auto + GetAQDramLoadWindow(const AQDramBlockWindowTmp& aq_dram_block_window_tmp) const + { + static_assert(std::is_same_v); + + using YPerTile = number; + using XPerTile = number; + + auto aq_copy_dram_window = + make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(YPerTile(), XPerTile()), + aq_dram_block_window_tmp.get_window_origin(), + Policy::template MakeAQDramTileDistribution()); + return aq_copy_dram_window; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp new file mode 100644 index 0000000000..83b61e23fc --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro 
Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "gemm_group_quant_utils.hpp" + +namespace ck_tile { + +struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPolicy +{ + using Base = UniversalGemmPipelineAgBgCrPolicy; + using Base::I0; + using Base::I1; + using Base::I2; + + using Base::ATileAccessPattern; + using Base::BTileAccessPattern; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ() + { + using AQLayout = remove_cvref_t; + using AQDataType = remove_cvref_t; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize; + + static_assert(std::is_same_v); + return GetAQGlobalVectorLoadSize(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeAQDramTileDistribution() + { + using AQLayout = remove_cvref_t; + using BlockGemmShape = typename Problem::BlockGemmShape; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize; + constexpr index_t VecLoadSize = GetVectorSizeAQ(); + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + + static_assert(std::is_same_v); + using TileEncodingPattern = TileDistributionEncodingPatternAQ; + + return TileEncodingPattern::Make2DStaticTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0, + "KPerWarpGemm must be a multiple of kQuantGroupSize!"); + + using WarpGemm = 
WarpGemmMfmaDispatcher; + static_assert(std::is_same_v || + std::is_same_v); + static_assert(std::is_same_v); + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + return AQuantBlockUniversalGemmAsBsCr{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp new file mode 100644 index 0000000000..9fb26eb4e0 --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -0,0 +1,476 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BaseAQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 +{ + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) + { + if(has_hot_loop) + { + if(tail_number == ck_tile::TailNumber::Full) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Odd) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Even) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("Unsupported tail number for this operation !!!"); + } + } + 
else + { + if(tail_number == ck_tile::TailNumber::Full) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Odd) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_number == ck_tile::TailNumber::Even) + { + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("Unsupported tail number for this operation !!!"); + } + } + } +}; + +template +struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV3 +{ + using Base = BaseGemmPipelineAgBgCrCompV3; + using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t AQPackedSize = + ck_tile::numeric_traits>::PackedSize; + + using ALayout = remove_cvref_t; + using AQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockGemm = remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize; + static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize; + + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + 
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + static constexpr index_t GetVectorSizeAQ() + { + return Policy::template GetVectorSizeAQ(); + } + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + using Base::PrefetchStages; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + return concat('_', "aquant_pipeline_AgBgCrCompV3", + concat('x', MPerBlock, NPerBlock, KPerBlock), + BlockSize, + concat('x', WaveNumM, WaveNumN), + concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK), + concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST static std::string Print() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr 
index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + constexpr index_t AQ_Buffer_Load_Inst_Num = + MPerBlock * KPerBlockAQ / (BlockSize * GetVectorSizeAQ()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + auto str = std::stringstream{}; + + str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", " + << "AQ vector size: " << GetVectorSizeAQ() << "\n" + << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n" + << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num + << ", " + << "AQ buffer load inst: " << AQ_Buffer_Load_Inst_Num << "\n" + << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num + << "\n" + << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n" + << "C MFMA inst: " << C_MFMA_Inst_Num << "\n" + << "QuantGroupSize: " << QuantGroupSize << "\n" + << "KPack: " << BlockGemm::Traits::KPack << "\n" + << "PrefetchStages: " << PrefetchStages << "\n"; + return str.str(); + } + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + template 
+ CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "A/B/AQ Dram block window should have the same data type as appropriate " + "([A|B|AQ]DataType) defined in Problem definition!"); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_aq_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)"); + static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}], + "Aq block window has incorrect lengths for defined AqLayout!"); + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? 
(KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex; + + auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp); + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + using AQBlockTile = + decltype(make_static_distributed_tensor(AQBlockTileDistr{})); + + auto block_gemm = BlockGemm(); + + ABlockTile a_block_tile; + BBlockTile b_block_tile; + AQBlockTile aq_block_tile[2]; + int currIdx = 0; + + auto c_block_tile = block_gemm.MakeCBlockTile(); + + constexpr ADramTileWindowStep 
a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr AQDramTileWindowStep aq_dram_tile_window_step = + is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ); + + // DRAM prefetch (global read 0) + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch( + aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step); + + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffled2DStaticTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffled2DStaticTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + __builtin_amdgcn_sched_barrier(0); + + if constexpr(HasHotLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + 
transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], + aq_copy_dram_window, + aq_dram_tile_window_step); + + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + + currIdx = (currIdx + 1) % 2; + + block_sync_lds(); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) + { + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + } + else + { + Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], + aq_copy_dram_window, + aq_dram_tile_window_step); + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + currIdx = (currIdx + 1) % 2; + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + 
auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm( + c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); + } + return c_block_tile; + } + }; + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + aq_dram_block_window_tmp, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp new file mode 100644 index 0000000000..4cca30fd3b --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" + +#include + +namespace ck_tile { + +template +struct GemmAQuantPipelineProblemBase : public GemmPipelineProblemBase +{ + using Base = GemmPipelineProblemBase; + + using Traits = typename Base::Traits; + + using typename Base::ADataType; + using typename Base::BDataType; + using typename Base::CDataType; + using typename Base::ComputeDataType; + using AQDataType = remove_cvref_t; + + using BlockGemmShape = typename Base::BlockGemmShape; + + using typename Base::ALayout; + using typename Base::BLayout; + using typename Base::CLayout; + + static constexpr bool TransposeC = false; + + using Base::kBlockSize; + + using Base::kPadK; + using Base::kPadM; + using Base::kPadN; + + using Base::DoubleSmemBuffer; + using Base::VectorLoadSize; + + using AQLayout = remove_cvref_t; + + static constexpr uint32_t kQuantGroupSize = QuantGroupSize_; + static constexpr auto Scheduler = Scheduler_; + static constexpr auto HasHotLoop = HasHotLoop_; + static constexpr auto TailNum = TailNum_; + + static_assert(BlockGemmShape::kK % kQuantGroupSize == 0); + static_assert(Scheduler == GemmPipelineScheduler::Intrawave); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm_aquant_problem", + concat('x', VectorLoadSize, kBlockSize), + concat('x', kPadM, kPadN, kPadK), + Scheduler, + "QuantGroupSize", + kQuantGroupSize); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentAQ() + { + static_assert(std::is_same_v); + return VectorLoadSize / sizeof(AQDataType); + } + + static constexpr index_t VectorSizeAQ = []() { + static_assert(std::is_same_v); + return kPadK ? 
1 : GetAlignmentAQ(); + }(); +}; + +template +using GemmAQuantPipelineProblem = GemmAQuantPipelineProblemBase; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp new file mode 100644 index 0000000000..c018314ab7 --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST_DEVICE static constexpr auto GetAQGlobalVectorLoadSize() +{ + using I1 = number<1>; + constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{}); + + constexpr index_t BlockSize = Problem::kBlockSize; + + // Data is replicated across warps along NWarps, so we divide BlockSize by NWarps + constexpr index_t elements_per_thread = (YPerTile * XPerTile) / (BlockSize / NWarps); + constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; + + // Define vector load candidates in descending order of priority + constexpr std::array candidates{ + PackedSize * 32 / sizeof(DataType), + PackedSize * 16 / sizeof(DataType), + PackedSize * 8 / sizeof(DataType), + PackedSize * 4 / sizeof(DataType), + PackedSize * 2 / sizeof(DataType), + }; + + for(const auto vec_size : candidates) + { + if(vec_size <= 0 || XPerTile % vec_size != 0 || elements_per_thread % vec_size != 0) + continue; + bool is_valid = (vec_size > 0) && (XPerTile % vec_size == 0) && + (elements_per_thread % vec_size == 0) && vec_size != candidates[4]; + if(is_valid) + { + return vec_size; + } + } + return PackedSize; // Absolute fallback +} + +// AQ holds groupquant scale data for A. Data is loaded from DRAM and partitioned across +// threads. 
Post mfma scales are shuffled across threads in the warp and applied to +// accum registers. +template +struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern +{ + // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! + static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); + static constexpr index_t warp_size = get_warp_size(); + static constexpr index_t num_warps = BlockSize / get_warp_size(); + + static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{}); + static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{}); + static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{}); + + static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM); + + static_assert(num_warps == MWarps * NWarps * KWarps); + + // KWarps > 1 isn't supported + static_assert(KWarps == 1); + + // # of elements per thread + static constexpr index_t X = XPerTile; + + static constexpr index_t Y0 = 1; + static constexpr index_t Y1 = MIterPerWarp ? 
MIterPerWarp : 1; + static constexpr index_t Y2 = MWarps; + static constexpr index_t Y3 = WarpGemm::kM; + static_assert(Y3 >= WarpGemm::kM, "Scales for all rows must be available within the warp."); + static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile, + "Y0, Y1, Y2, Y3 must cover the blocktile along Y."); + + CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 1>>, + tuple, sequence<0, 3>>, + sequence<1, 2>, + sequence<1, 0>>{}); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp new file mode 100644 index 0000000000..4972badb3f --- /dev/null +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_aquant_traits.hpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct TileGemmAQuantTraits +{ + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kPadK = kPadK_; + + static constexpr int _VectorSize = 16; + + using ALayout = ALayout_; + using BLayout = BLayout_; + using CLayout = CLayout_; + using AQLayout = AQLayout_; + + static constexpr bool UseStructuredSparsity = false; + static constexpr index_t NumWaveGroups = 1; +}; + +} // namespace ck_tile diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 3e5a3034cd..8f3fbd52c5 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -17,3 +17,4 @@ add_subdirectory(topk_softmax) add_subdirectory(add_rmsnorm2d_rdquant) # add_subdirectory(layernorm2d) # add_subdirectory(rmsnorm2d) +add_subdirectory(gemm_block_scale) diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt new file mode 100644 index 0000000000..847ab88644 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -0,0 +1,19 @@ +set(TEST_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND TEST_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) + +if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") + set(TEST_GEMM_NAME test_tile_gemm_aquant_basic) + set(QUANT_TYPES fp8 bf8 i4fp8 i4bf8 i4f32fp8 i4f32bf8) + + foreach(QUANT_TYPE ${QUANT_TYPES}) + add_gtest_executable(${TEST_GEMM_NAME}_${QUANT_TYPE} test_gemm_aquant_basic_${QUANT_TYPE}.cpp) + target_compile_options(${TEST_GEMM_NAME}_${QUANT_TYPE} PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + endforeach() + +else() + message(DEBUG "Skipping ck_tile quant gemm tests for current target") +endif() diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_bf8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_bf8.cpp new file mode 
100644 index 0000000000..9c4277d879 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_bf8.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_run_gemm_aquant_example.inc" + +int main() { return run_gemm_combinations("bf8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_fp8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_fp8.cpp new file mode 100644 index 0000000000..b0cf55be6f --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_fp8.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_run_gemm_aquant_example.inc" + +int main() { return run_gemm_combinations("fp8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4bf8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4bf8.cpp new file mode 100644 index 0000000000..fd80bf2b06 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4bf8.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_run_gemm_aquant_example.inc" + +int main() { return run_gemm_combinations("i4bf8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32bf8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32bf8.cpp new file mode 100644 index 0000000000..fe8c9c5000 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32bf8.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "test_run_gemm_aquant_example.inc" + +int main() { return run_gemm_combinations("i4f32bf8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32fp8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32fp8.cpp new file mode 100644 index 0000000000..a319d9c2ad --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32fp8.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_run_gemm_aquant_example.inc" + +int main() { return run_gemm_combinations("i4f32fp8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4fp8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4fp8.cpp new file mode 100644 index 0000000000..ceb8760435 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4fp8.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_run_gemm_aquant_example.inc" + +int main() { return run_gemm_combinations("i4fp8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp new file mode 100644 index 0000000000..40f6712ef9 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp @@ -0,0 +1,681 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm_group_quant.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_COMPUTE_V5 4 +#define CK_TILE_PIPELINE_PRESHUFFLE 5 + +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if defined(__gfx950__) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +} +template +constexpr ck_tile::index_t get_k_warp_tile_flatmm() +{ +#if defined(__gfx950__) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 
32 : 64; +#endif +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +class ArgumentsNotSupportedException : public std::logic_error +{ + public: + explicit ArgumentsNotSupportedException(const std::string& message) : logic_error(message) {} +}; + +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool Preshuffle = false; +}; + +template +struct GemmConfigMemoryInterwave : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / 
sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template +struct GemmConfigMemoryIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; +}; + +template +struct GemmConfigComputeV3 : public GemmConfigBase +{ + // Compute V3 only support Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 32; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + 
static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; +}; + +template +struct GemmConfigComputeV4 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV4_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV5 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64 / 
sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 2; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaveGroups = 2; +}; + +template +struct GemmConfigPreshufle_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + +template +struct GemmConfigPreshufle_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int 
kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + +template +struct GemmTypeConfig; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::int8_t; + using BDataType = ck_tile::int8_t; + using AccDataType = int32_t; + using CDataType = int32_t; +}; + +template +struct GemmQuantTypeConfig +{ + using ADataType = ADataType_; + using QDataType = QDataType_; + using BDataType = BDataType_; + using AccDataType = float; + using CDataType = CDataType_; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::half_t; + using QDataType = float; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using 
QDataType = float; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = float; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = float; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::half_t; + using QDataType = float; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = float; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = float; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using 
AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = float; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::pk_int4_t; + using QDataType = float; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = ck_tile::fp8_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = ck_tile::bf8_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using QDataType = float; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template <> +struct GemmQuantTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using QDataType = float; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = float; +}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + 
+template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; + template + using UniversalGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("aq_layout", "R", "Aq tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Column by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_q", "0", "Tensor AQ stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. 
Validation on GPU") + .insert("prec", "i4fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("persistent", "0", "0:non-persistent, 1:persistent") + .insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// host API +float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc new file mode 100644 index 0000000000..f410b58053 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc @@ -0,0 +1,577 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host.hpp" +#include "test_gemm_aquant_utils.hpp" + +template +float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s) +{ + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr int kBlockPerCu = 1; + + static_assert(std::is_same_v); + + constexpr ck_tile::index_t M_Tile = 16; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 256; + + constexpr ck_tile::index_t M_Warp = 1; + constexpr ck_tile::index_t N_Warp = 4; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 16; + constexpr ck_tile::index_t N_Warp_Tile = 16; + constexpr ck_tile::index_t K_Warp_Tile = 32; + + using CodegenGemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + + using CodegenGemmTraits = + ck_tile::TileGemmAQuantTraits; + + using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; + + using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + constexpr bool transposed_warp_gemm = false; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + + using CodegenPipelineProblem = + ck_tile::GemmAQuantPipelineProblem; + using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + 
ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + CodegenPipelineProblem::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + transposed_warp_gemm, + ck_tile::memory_operation_enum::set>>; + using Kernel = + ck_tile::AQuantGemmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(args.k_batch != 1) + { + throw std::runtime_error("split-k is not supported yet!"); + } + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& aq_m_aqk_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t AQK, + ck_tile::index_t stride_A, + ck_tile::index_t stride_AQ, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + 
ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) +{ + ck_tile::AQuantGemmHostArgs args; + args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer(); + args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.k_batch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.QK = AQK; + args.stride_A = stride_A; + args.stride_B = stride_B; + args.stride_C = stride_C; + args.stride_AQ = stride_AQ; + + float ave_time = gemm_calc_aquant( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(AQDataType) * M * AQK + + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B + << " StrideC =" << stride_C << " A_Layout =" << ALayout::name + << " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name + << " A_Type = " << DataTypeTraits::name + << " AQ_Type = " << DataTypeTraits::name + << " B_Type = " << DataTypeTraits::name + << " Acc_Type = " << DataTypeTraits::name + << " C_Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + return ave_time; +} + +template +bool run_gemm_test_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const AQLayout aq_layout = AQLayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return false; + + using ADataType = typename TypeConfig::ADataType; + using AQDataType = typename TypeConfig::QDataType; + using BDataType = typename 
TypeConfig::BDataType; + using AccDataType = typename TypeConfig::AccDataType; + using CDataType = typename TypeConfig::CDataType; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + if(K % QuantGroupSize != 0) + { + throw std::runtime_error("K must be aligned with QuantGroupSize"); + } + + ck_tile::index_t AQK = K / QuantGroupSize; + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor aq_m_aqk( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution fill_seed(0, 500); + + if(init_method == 0) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, 
fill_seed(gen)}(a_m_k); + } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(aq_m_aqk); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + } + else if(init_method == 1) + { + std::cout << "Monotonic initialization is not supported." << std::endl; + return true; + } + else if(init_method == 2) + { + ck_tile::FillConstant{static_cast(0x22)}(a_m_k); + ck_tile::FillConstant{static_cast(0.5f)}(aq_m_aqk); + ck_tile::FillConstant{static_cast(0x38)}(b_k_n); + } + else + { + a_m_k.SetZero(); + aq_m_aqk.SetZero(); + b_k_n.SetZero(); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + invoke_gemm(a_m_k_dev_buf, + aq_m_aqk_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + AQK, + stride_A, + stride_AQ, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_quant(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + if(!pass) + { + 
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl; + return false; + } + + return pass; +} + +template +bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_test_with_layouts( + argc, argv, Row{}, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } + } + else + { + throw std::runtime_error("Unsupported data type for A."); + } + + return true; +} + +bool run_gemm_test(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return false; + + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(data_type == "fp8") + { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "bf8") + { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4fp8") + { + using TypeConfig = decltype( + GemmQuantTypeConfig{}); + return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4bf8") + { + using TypeConfig = decltype( + GemmQuantTypeConfig{}); + return run_gemm_test_prec_type(a_layout, b_layout, argc, 
argv); + } + else if(data_type == "i4f32fp8") + { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "i4f32bf8") + { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } +} + +int run_gemm_combinations(std::string const& data_type) +{ + // Define possible values for each parameter + std::vector> mnk_values = {{ + "1", + "2048", + "5120", + }, + { + "2", + "2048", + "5120", + }, + { + "16", + "2048", + "5120", + }, + { + "17", + "2048", + "5120", + }, + { + "2047", + "5120", + "1024", + }, + { + "2048", + "5120", + "1024", + }}; + std::vector prec_values = {data_type}; + + // We'll store all our arguments as strings first + std::vector arg_strings = {"test_tile_gemm_aquant_basic", + "", // m placeholder + "", // n placeholder + "", // k placeholder + "", // prec placeholder + "-init=0", + "-v=1", + "-warmup=0", + "-repeat=1"}; + + // Create an array of const char pointers for argv + constexpr size_t ARG_COUNT = 9; + constexpr size_t ARG_MAX_LEN = 64; + char args[ARG_COUNT][ARG_MAX_LEN]; + char* argv[ARG_COUNT]; + + // Run all combinations + bool is_success = true; + for(const auto& mnk : mnk_values) + { + arg_strings[1] = "-m=" + mnk[0]; + arg_strings[2] = "-n=" + mnk[1]; + arg_strings[3] = "-k=" + mnk[2]; + + for(const auto& prec : prec_values) + { + arg_strings[4] = "-prec=" + prec; + + // Set up the argv array with pointers to the string data + for(size_t i = 0; i < ARG_COUNT; i++) + { + strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN); + argv[i] = args[i]; + } + + std::cout << "Arguments received: "; + for(size_t i = 1; i < ARG_COUNT; ++i) + { + std::cout << argv[i] << " "; + } + std::cout << std::endl; + + // Call the function with the current configuration + try + { + is_success = 
run_gemm_test(ARG_COUNT, argv) && is_success; + } + catch(const ArgumentsNotSupportedException& e) + { + std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n'; + // ArgumentsNotSupportedException is not an error. Do not change is_success + } + catch(const std::runtime_error& e) + { + std::cerr << "Caught runtime error: " << e.what() << '\n'; + is_success = false; + } + } + } + return is_success ? EXIT_SUCCESS : EXIT_FAILURE; +}