diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index d9a6c2bac0..7358d4d749 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -8,12 +8,6 @@ list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp) target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - - add_executable(tile_example_gemm_aquant_preshuffle EXCLUDE_FROM_ALL gemm_aquant_preshuffle.cpp) - target_compile_options(tile_example_gemm_aquant_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - - add_executable(tile_example_gemm_bquant_basic EXCLUDE_FROM_ALL gemm_bquant_basic.cpp) - target_compile_options(tile_example_gemm_bquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile quant gemm tests for current target") endif() diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index 36f7995c14..9acc4f9bfc 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -3,6 +3,7 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming implementation. - AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline +- BQuant kernel with blocks of B matrix sharing scales: custom GEMM pipeline - Row and Column-wise scaled: scaling implemented in Epilogue ## build @@ -38,3 +39,13 @@ args: -timer gpu:gpu timer, cpu:cpu timer (default:gpu) -quant_mode Which quant method to use (aquant, rowcol) ``` + +User need to select correct mapping of config for each quant mode: + +| | quant_mode as runtime argument | Config in cpp file | +|:--------|:-----:|-------| +| For selecting AQuant | aquant | GemmConfigQuant | +| For selecting Aquant with Preshuffle | aquant | GemmConfigPreshuffleQuant | +| For selecting BQuant | bquant | GemmConfigQuant | +| For selecting RowCol quant | rowcolquant | GemmConfigRowColQuant | + diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp deleted file mode 100644 index 799e4a0e73..0000000000 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp +++ /dev/null @@ -1,234 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include -#include -#include - -#include "gemm_utils.hpp" - -template -float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s) -{ - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - - constexpr int kBlockPerCu = 1; - - static_assert(std::is_same_v); - - constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile; - constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile; - constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile; - - constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp; - constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp; - constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp; - - constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile; - constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile; - constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile; - - using CodegenGemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - - using CodegenGemmTraits = ck_tile::TileGemmQuantTraits; - - using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; - - using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3; - - const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - constexpr bool transposed_warp_gemm = false; - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - - using CodegenPipelineProblem = - ck_tile::GemmAQuantPipelineProblem; - using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - transposed_warp_gemm, - ck_tile::memory_operation_enum::set>>; - using Kernel = ck_tile::QuantGemmKernel; - - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(args.k_batch != 1) - { - throw std::runtime_error("split-k is not supported yet!"); - } - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenGemmShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; - }; - return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); -} - -#include "run_gemm_aquant_example.inc" - -template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) -{ - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) - { - if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices!"); - } - } - else - { - throw std::runtime_error("Unsupported data type for A."); - } - - return 0; -} - -template