diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 12cf874c73..d9a6c2bac0 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -6,8 +6,12 @@ endif() list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") - add_executable(tile_example_gemm_aquant_basic EXCLUDE_FROM_ALL gemm_aquant_basic.cpp) - target_compile_options(tile_example_gemm_aquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp) + target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + + add_executable(tile_example_gemm_aquant_preshuffle EXCLUDE_FROM_ALL gemm_aquant_preshuffle.cpp) + target_compile_options(tile_example_gemm_aquant_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_executable(tile_example_gemm_bquant_basic EXCLUDE_FROM_ALL gemm_bquant_basic.cpp) target_compile_options(tile_example_gemm_bquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index 6d6aec28c8..36f7995c14 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -1,18 +1,21 @@ -# GEMM Matrix Multiplication +# Quant GEMM Matrix Multiplication -This folder contains example for Block Scale GEMM using ck_tile tile-programming implementation. +This folder contains examples of quant GEMMs using the ck_tile tile-programming implementation. 
+ +- AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline +- Row and Column-wise scaled: scaling implemented in Epilogue ## build ``` # in the root of ck_tile mkdir build && cd build -# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +# you can replace <arch> with the appropriate architecture (for example gfx942) or leave it blank ../script/cmake-ck-dev.sh ../ <arch> -# The aquant pipeline method on the gemm calculation -make tile_example_gemm_aquant_basic -j +# Compile the quant kernels +make tile_example_gemm_quant_basic -j make tile_example_gemm_bquant_basic -j ``` -This will result in an executable `build/bin/tile_example_gemm_aquant_basic` +This will result in an executable `build/bin/tile_example_gemm_quant_basic` ## example ``` @@ -22,15 +25,16 @@ args: -n n dimension (default:2048) -k k dimension (default:64) -a_layout Tensor A data layout (default: R) - -b_layout Tensor B data layout (default: R) + -b_layout Tensor B data layout (default: C) -c_layout Tensor C data layout (default: R) -stride_a Tensor A stride (default:0) -stride_b Tensor B stride (default:0) -stride_c Tensor C stride (default:0) - -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) + -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1) -e Absolute error tolerance (default:1e-5) - -prec data type. 
fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8 (default:fp8) -warmup number of iterations before benchmark the kernel (default:10) -repeat number of iterations to benchmark the kernel (default:100) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -quant_mode Which quant method to use (aquant, rowcol) ``` diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp deleted file mode 100644 index d5a38fe754..0000000000 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp +++ /dev/null @@ -1,226 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include -#include -#include - -#include "gemm_utils.hpp" - -template -float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s) -{ - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - - static_assert(std::is_same_v); - - constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile; - constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile; - constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile; - - constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp; - constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp; - constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp; - - constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile; - constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile; - constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile; - - using CodegenGemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - - using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits; - - using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; - - using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3; - - const ck_tile::index_t K_split = (args.K + K_Tile - 1) 
/ K_Tile * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - constexpr bool transposed_warp_gemm = true; - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - - using CodegenPipelineProblem = - ck_tile::GemmAQuantPipelineProblem; - using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - transposed_warp_gemm, - ck_tile::memory_operation_enum::set>>; - using Kernel = - ck_tile::AQuantGemmKernel; - - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(args.k_batch != 1) - { - throw std::runtime_error("split-k is not supported yet!"); - } - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenGemmShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; - }; - return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); -} - -#include "run_gemm_aquant_example.inc" - -template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) -{ - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) - { - if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices!"); - } - } - else - { - throw std::runtime_error("Unsupported data type for A."); - } - - return 0; -} - -template