diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 825cd6e522..d2112a67bf 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -2,6 +2,7 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp) add_executable(tile_example_gemm_reduce EXCLUDE_FROM_ALL gemm_splitk_two_stage_reduce.cpp) +add_executable(tile_example_gemm_splitk_two_stage EXCLUDE_FROM_ALL gemm_splitk_two_stage.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) @@ -16,3 +17,4 @@ target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OP target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_gemm_splitk_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 99c943a7f1..d687e35f5d 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -2,185 +2,9 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_utils.hpp" - -template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) - -{ - if constexpr(Persistent) - std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." 
<< std::endl; - - // This part comes from the Codegen - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 64; - -#if CK_TILE_USE_WMMA - constexpr ck_tile::index_t M_Warp = 4; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 16; - constexpr ck_tile::index_t N_Warp_Tile = 16; - constexpr ck_tile::index_t K_Warp_Tile = 16; -#else - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; -#endif - - using CodegenGemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - - using CodegenGemmTraits = ck_tile::TileGemmTraits; - - using CodegenPipelineProblem = ck_tile:: - GemmPipelineProblem; - - using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - CodegenPipelineProblem::TransposeC, - memory_operation>>; - - // ToDo: Will add the codegen part to test different pipeline policies in GEMM. - // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. 
- using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenGemmShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; - }; - - if(args.k_batch == 1) - { - return Run(MemoryOpSet{}); - } - else - { - return Run(MemoryOpAtomicAdd{}); - } -} - #include "run_gemm_example.inc" - -template -int run_gemm_example_prec_type(std::string a_layout, - std::string b_layout, - ck_tile::ArgParser& arg_parser) -{ - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - if constexpr(std::is_same_v) - { - if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( - arg_parser, Row{}, Col{}, Row{}); - } - else if(a_layout == "C" && b_layout == "C") - { - return run_gemm_example_with_layouts( - arg_parser, Col{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices when " - "BPrecType is ck_tile::pk_int4_t!"); - } - } - else - { - if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( - arg_parser, Row{}, Col{}, Row{}); - } - else if(a_layout == "R" && b_layout == "R") - { - return run_gemm_example_with_layouts( - arg_parser, Row{}, Row{}, Row{}); - } 
- else if(a_layout == "C" && b_layout == "R") - { - return run_gemm_example_with_layouts( - arg_parser, Col{}, Row{}, Row{}); - } - else if(a_layout == "C" && b_layout == "C") - { - return run_gemm_example_with_layouts( - arg_parser, Col{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices!"); - } - } -} +#include "run_gemm_example_common.hpp" +#include "gemm_basic_invoker.hpp" int run_gemm_example(ck_tile::ArgParser& arg_parser) { @@ -188,36 +12,53 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); + using GemmConfig = GemmConfigBase; + using Invoker = BasicInvoker; + if(data_type == "fp16") { - return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); + return run_gemm_example_prec_type( + a_layout, b_layout, arg_parser); } else if(data_type == "bf16") { - return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); + return run_gemm_example_prec_type( + a_layout, b_layout, arg_parser); } else if(data_type == "fp8") { - return run_gemm_example_prec_type( - a_layout, b_layout, arg_parser); + return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else if(data_type == "bf8") { - return run_gemm_example_prec_type( - a_layout, b_layout, arg_parser); + return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else if(data_type == "i8") { - return run_gemm_example_prec_type( - a_layout, b_layout, arg_parser); + return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else if(data_type == "pk_int4_t") { // TODO: Add support for bhalf_t ADataType - if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) { - return run_gemm_example_prec_type( - a_layout, b_layout, arg_parser); + return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else { @@ -232,7 +73,9 @@ 
int run_gemm_example(ck_tile::ArgParser& arg_parser) int main(int argc, char* argv[]) { - auto [result, arg_parser] = create_args(argc, argv); + auto arg_parser = create_args(); + auto result = arg_parser.parse(argc, argv); + if(!result) return -1; diff --git a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp new file mode 100644 index 0000000000..861374e268 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include "gemm_utils.hpp" + +struct BasicInvoker +{ + template + static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + { + if constexpr(Persistent) + { + std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl; + } + + // This part comes from the Codegen + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; + +#if CK_TILE_USE_WMMA + constexpr ck_tile::index_t M_Warp = 4; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 16; + constexpr ck_tile::index_t N_Warp_Tile = 16; + constexpr ck_tile::index_t K_Warp_Tile = 16; +#else + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; +#endif + + using CodegenGemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + + using CodegenGemmTraits = ck_tile::TileGemmTraits; + + using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; + + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto 
Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + CodegenPipelineProblem::TransposeC, + memory_operation>>; + + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; + + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + + if(s.flush_cache_) + { + std::cout << "Flushing cache..." 
<< std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = + std::make_unique>( + kargs.as_ptr[0], + kargs.bs_ptr[0], + s.rotating_count_, + size_a_buffer, + size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; + } + else + { + preprocess = clear_gemm_output; + } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + }; + + if(args.k_batch == 1) + { + return Run(MemoryOpSet{}); + } + else + { + return Run(MemoryOpAtomicAdd{}); + } + } +}; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage.cpp new file mode 100644 index 0000000000..0455e8e34d --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gemm_utils.hpp" +#include "run_gemm_example.inc" +#include "run_gemm_example_common.hpp" +#include "gemm_splitk_two_stage_invoker.hpp" + +int run_gemm_example(ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + using Invoker = SplitKTwoStageInvoker; + + if(data_type == "fp16") + { + return run_gemm_example_prec_type, + Invoker, + ck_tile::half_t>(a_layout, b_layout, arg_parser); + } + else if(data_type == "bf16") + { + return run_gemm_example_prec_type, + Invoker, + ck_tile::bf16_t>(a_layout, b_layout, arg_parser); + } + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } +} + +int main(int argc, char* argv[]) +{ + auto arg_parser = create_args(); + auto result = arg_parser.parse(argc, argv); + + if(!result) + return -1; + + try + { + return !run_gemm_example(arg_parser); + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; + } +} diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp new file mode 100644 index 0000000000..21867816e2 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp @@ -0,0 +1,259 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+#pragma once + +#include "gemm_utils.hpp" +#include "ck_tile/ops/elementwise.hpp" + +template +struct GemmConfigTwoStage : public GemmConfigComputeV3 +{ + using WorkspaceType = ck_tile::remove_cvref_t; +}; + +struct SplitKTwoStageInvoker +{ + template + static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + + { + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + + using WorkspaceType = ck_tile::remove_cvref_t; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using GemmKernel = ck_tile::GemmKernel; + + 
ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType)); + ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args); + auto c_ptr = ws_args.c_ptr; + ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); + auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); + + const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s) + : GemmKernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = GemmKernel::BlockSize(); + + if(!GemmKernel::IsSupportedArgument(gemm_kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; + using BlockTile = ck_tile::sequence<2048>; + using BlockWarps = ck_tile::sequence<8>; + using WarpTile = ck_tile::sequence<64>; + + using ElementwiseShape = + ck_tile::ElementWiseShape; + using Problem = ck_tile::ElementWisePipelineProblem; + using ElementwiseKernel = + ck_tile::ElementWiseKernel; + + ck_tile::index_t total_elements = 1; + std::vector shape = {args.M, args.N}; + + for(auto d : shape) + total_elements *= d; + + constexpr ck_tile::index_t kBlockSize = + ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{}); + constexpr ck_tile::index_t kBlockPerCu = 1; + + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = + (total_elements + elements_per_block - 1) / elements_per_block; + + auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.c_ptr)); + auto input_size = ck_tile::make_tuple(args.M, args.N); + + // Check if the kernel configuration is supported + if(!ElementwiseKernel::IsSupportedArgument(input_size)) + { + throw std::runtime_error( + "Wrong! Elementwise arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; + + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_)); + }; + + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = + std::make_unique>( + gemm_kargs.as_ptr[0], + gemm_kargs.bs_ptr[0], + s.rotating_count_, + size_a_buffer, + size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; + } + else + { + preprocess = clear_gemm_output; + } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel( + GemmKernel{}, grids, blocks, 0, gemm_kargs), + ck_tile::make_kernel(ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(args.N, 1), // Input Stride + ck_tile::make_tuple(args.N, 1), // Output Stride + input_tensors, + static_cast(c_ptr))); + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const 
auto tail_number_) { + if(args.k_batch == 1) + { + ave_time = Run(has_hot_loop_, tail_number_, MemoryOpSet{}); // keep measured time + } + else + { + ave_time = Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); // keep measured time + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; + } +}; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index f42135a0b5..324dfc069a 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -608,16 +608,11 @@ template -int run_gemm_example_with_layouts_two_stage(int argc, - char* argv[], +int run_gemm_example_with_layouts_two_stage(ck_tile::ArgParser& arg_parser, const ALayout a_layout = ALayout{}, const BLayout b_layout = BLayout{}, [[maybe_unused]] const CLayout c_layout = CLayout{}) { - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - using AccDataType = typename GemmTypeConfig::AccDataType; ck_tile::index_t M = arg_parser.get_int("m"); @@ -837,12 +832,13 @@ template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +int run_gemm_example_prec_type(std::string a_layout, + std::string b_layout, + ck_tile::ArgParser& arg_parser) { - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - auto [result, arg_parser] = create_args(argc, argv); - bool preshuffle = GemmConfig::Preshuffle; + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + bool preshuffle = GemmConfig::Preshuffle; if(preshuffle && std::is_same_v) { @@ -866,7 +862,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a CPrecType, Row, Col, - Row>(argc, argv, Row{}, Col{}, Row{}); + Row>(arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { @@ -876,7 +872,7 @@ int 
run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a CPrecType, Col, Col, - Row>(argc, argv, Col{}, Col{}, Row{}); + Row>(arg_parser, Col{}, Col{}, Row{}); } else { @@ -892,7 +888,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a APrecType, BPrecType, CPrecType>( - argc, argv, Row{}, Row{}, Row{}); + arg_parser, Row{}, Row{}, Row{}); } if(a_layout == "R" && b_layout == "C") { @@ -900,7 +896,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a APrecType, BPrecType, CPrecType>( - argc, argv, Row{}, Col{}, Row{}); + arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "R") { @@ -908,7 +904,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a APrecType, BPrecType, CPrecType>( - argc, argv, Col{}, Row{}, Row{}); + arg_parser, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { @@ -916,7 +912,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a APrecType, BPrecType, CPrecType>( - argc, argv, Col{}, Col{}, Row{}); + arg_parser, Col{}, Col{}, Row{}); } else { @@ -927,12 +923,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } template