diff --git a/CHANGELOG.md b/CHANGELOG.md index 76fb46cdd9..8ae97b3d61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added int8 support for CK_TILE GEMM. * Added support for elementwise kernel. * Added benchmarking support for tile engine GEMM Multi D. +* Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands. ### Optimized diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 914fdac0e4..12cf874c73 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -8,9 +8,8 @@ list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") add_executable(tile_example_gemm_aquant_basic EXCLUDE_FROM_ALL gemm_aquant_basic.cpp) target_compile_options(tile_example_gemm_aquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - - add_executable(tile_example_gemm_aquant_preshuffle EXCLUDE_FROM_ALL gemm_aquant_preshuffle.cpp) - target_compile_options(tile_example_gemm_aquant_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_executable(tile_example_gemm_bquant_basic EXCLUDE_FROM_ALL gemm_bquant_basic.cpp) + target_compile_options(tile_example_gemm_bquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile quant gemm tests for current target") endif() diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index fc905790f1..6d6aec28c8 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -10,6 +10,7 @@ mkdir build && cd build ../script/cmake-ck-dev.sh ../ # The aquant pipeline method on the gemm calculation make tile_example_gemm_aquant_basic -j +make 
tile_example_gemm_bquant_basic -j ``` This will result in an executable `build/bin/tile_example_gemm_aquant_basic` diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp index 744c844040..d5a38fe754 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp @@ -27,8 +27,6 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s constexpr bool kPadN = false; constexpr bool kPadK = false; - constexpr int kBlockPerCu = 1; - static_assert(std::is_same_v); constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile; @@ -139,7 +137,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s } float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; @@ -207,7 +205,7 @@ int run_gemm_example(int argc, char* argv[]) ck_tile::fp8_t, ck_tile::half_t, ck_tile::fp8_t>{}); - return run_gemm_example_prec_type, TypeConfig, 128>( + return run_gemm_example_prec_type, TypeConfig, 128>( a_layout, b_layout, argc, argv); } else if(data_type == "i4bf8") @@ -216,7 +214,7 @@ int run_gemm_example(int argc, char* argv[]) ck_tile::bf8_t, ck_tile::half_t, ck_tile::bf8_t>{}); - return run_gemm_example_prec_type, TypeConfig, 128>( + return run_gemm_example_prec_type, TypeConfig, 128>( a_layout, b_layout, argc, argv); } else diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp new file mode 100644 index 0000000000..991c4841e4 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+
+// NOTE(review): the six system header names below are not visible in this view
+// (the <...> part is missing) - confirm against the real file.
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "ck_tile/core/config.hpp"
+#include "ck_tile/host.hpp"
+#include "gemm_utils.hpp"
+
+// Builds the B-quant GEMM kernel for the tile configuration in GemmConfig and launches it.
+//
+// args: host-side problem description; M, N, K and k_batch are read here.
+// s:    stream configuration; log_level_ > 0 prints the kernel, shape, problem and
+//       pipeline names plus the launch dimensions before launching.
+//
+// Returns whatever ck_tile::launch_kernel returns (presumably the average kernel
+// time - confirm the unit against ck_tile::launch_kernel).
+//
+// Throws std::runtime_error when args.k_batch != 1 (split-k is not implemented) or when
+// Kernel::IsSupportedArgument rejects the kernel arguments.
+//
+// NOTE(review): the template parameter list and most template arguments are not
+// visible in this view - confirm against the real file.
+template
+float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::stream_config& s)
+{
+    // Padding is disabled in all three dimensions for this example.
+    constexpr bool kPadM = false;
+    constexpr bool kPadN = false;
+    constexpr bool kPadK = false;
+
+    static_assert(std::is_same_v);
+
+    // Block-level tile extents.
+    constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
+    constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
+    constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
+
+    // Warp counts per block along each dimension.
+    constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
+    constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
+    constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
+
+    // Per-warp tile extents.
+    constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
+    constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
+    constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
+
+    using CodegenGemmShape =
+        ck_tile::TileGemmShape,
+                               ck_tile::sequence,
+                               ck_tile::sequence>;
+
+    using TilePartitioner = ck_tile::GemmTile1DPartitioner;
+
+    using CodegenGemmTraits = ck_tile::TileGemmBQuantTraits;
+
+    using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase;
+
+    using BaseGemmPipeline = ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3;
+
+    // Round K up to a whole number of K tiles before deriving the main-loop trip count.
+    const ck_tile::index_t K_split  = (args.K + K_Tile - 1) / K_Tile * K_Tile;
+    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
+    const bool has_hot_loop         = BaseGemmPipeline::BlockHasHotloop(num_loop);
+    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
+    constexpr bool transposed_warp_gemm = false;
+
+    // Invoked through BaseGemmPipeline::TailHandler below. Both parameters expose a
+    // constexpr `.value`, which is read into has_hot_loop_v / tail_number_v.
+    const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
+        constexpr bool has_hot_loop_v = has_hot_loop_.value;
+        constexpr auto tail_number_v  = tail_number_.value;
+
+        using CodegenPipelineProblem =
+            ck_tile::GemmBQuantPipelineProblem;
+        using CodegenGemmPipeline = ck_tile::BQuantGemmPipelineAgBgCrCompV3;
+        using GemmEpilogue = ck_tile::CShuffleEpilogue<
+            ck_tile::CShuffleEpilogueProblem,
+                                             AccDataType,
+                                             CDataType,
+                                             ck_tile::tuple<>,
+                                             CLayout,
+                                             ck_tile::element_wise::PassThrough,
+                                             TilePartitioner::MPerBlock,
+                                             TilePartitioner::NPerBlock,
+                                             M_Warp,
+                                             N_Warp,
+                                             M_Warp_Tile,
+                                             N_Warp_Tile,
+                                             K_Warp_Tile,
+                                             transposed_warp_gemm,
+                                             ck_tile::memory_operation_enum::set>>;
+        using Kernel =
+            ck_tile::BQuantGemmKernel;
+
+        auto kargs = Kernel::MakeKernelArgs(args);
+
+        const dim3 grids  = Kernel::GridSize(args.M, args.N, args.k_batch);
+        const dim3 blocks = Kernel::BlockSize();
+
+        // Split-k (k_batch > 1) has no implementation in this example yet.
+        if(args.k_batch != 1)
+        {
+            throw std::runtime_error("split-k is not supported yet!");
+        }
+
+        if(!Kernel::IsSupportedArgument(kargs))
+        {
+            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
+        }
+
+        if(s.log_level_ > 0)
+        {
+            std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
+                      << "shape: " << CodegenGemmShape::GetName() << '\n'
+                      << "problem: " << CodegenPipelineProblem::GetName() << '\n'
+                      << "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
+                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
+                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
+                      << std::endl;
+        }
+
+        float ave_time = ck_tile::launch_kernel(
+            s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs));
+
+        return ave_time;
+    };
+    return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
+}
+
+#include "run_gemm_bquant_example.inc"
+
+// Maps the layout strings onto layout tag types and forwards to
+// run_gemm_example_with_layouts.
+//
+// Only a_layout == "R" together with b_layout == "C" is accepted; any other combination
+// throws std::runtime_error. When the compile-time data-type check fails, the function
+// throws std::runtime_error("Unsupported data type for B.") instead.
+//
+// NOTE(review): the template parameter list and the operands of the is_same_v checks
+// are not visible in this view - confirm against the real file.
+template
+int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
+{
+    using Row = ck_tile::tensor_layout::gemm::RowMajor;
+    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
+
+    if constexpr(std::is_same_v ||
+                 std::is_same_v ||
+                 std::is_same_v)
+    {
+        if(a_layout == "R" && b_layout == "C")
+        {
+            return run_gemm_example_with_layouts(
+                argc, argv, Row{}, Col{}, Col{}, Row{});
+        }
+        else
+        {
+            throw std::runtime_error("Unsupported memory layout for the input matrices!");
+        }
+    }
+    else
+    {
+        throw std::runtime_error("Unsupported data type for B.");
+    }
+
+    // Not reached: every branch above either returns or throws.
+    return 0;
+}
+
+template