From af9520e598630b57757b14e7bdafa2664d548253 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Mon, 6 Oct 2025 15:43:23 -0400 Subject: [PATCH] fix: nil performance results for gemm examples (#2950) [ROCm/composable_kernel commit: 19415d0b6f7766e0523baad10ef0a53232b1defd] --- .../03_gemm/gemm_splitk_two_stage_invoker.hpp | 7 +- .../03_gemm/gemm_splitk_two_stage_reduce.cpp | 27 ++- .../03_gemm/universal_gemm_invoker.hpp | 7 +- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 181 +++++++++--------- .../grouped_gemm_preshuffle.cpp | 150 +++++++-------- .../17_grouped_gemm/quant_grouped_gemm.cpp | 26 ++- example/ck_tile/18_flatmm/flatmm_basic.cpp | 36 ++-- 7 files changed, 208 insertions(+), 226 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp index 8c7589dabb..9ece1638b5 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp @@ -252,15 +252,14 @@ struct SplitKTwoStageInvoker const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { - Run(has_hot_loop_, tail_number_, MemoryOpSet{}); + return Run(has_hot_loop_, tail_number_, MemoryOpSet{}); } else { - Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); + return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); } }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } }; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index f200332588..dd13ed7bba 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -275,30 +275,29 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& hipGetErrorString(hipMemsetAsync( args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - ave_time = ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time = ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } - return ave_time; }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { // For workspace mode, always use SET operation since each K-split writes to separate memory - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } /** diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index e0d97a50db..d0fd69b1e2 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -185,15 +185,14 @@ struct UniversalInvoker const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { - Run(has_hot_loop_, tail_number_, MemoryOpSet{}); + return Run(has_hot_loop_, tail_number_, MemoryOpSet{}); } else { - Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); + return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); } }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } }; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 606d98d9e2..f5335c3ec0 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -70,99 +70,95 @@ float grouped_gemm(const std::vector& gemm_descs, float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - - return ave_time; - }; + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(gemm_descs[0].k_batch == 1) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } else { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } template ( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); }; if(!splitk) { - Run(ck_tile::integral_constant{}); + return ave_time = Run(ck_tile::integral_constant{}); } else { - Run(ck_tile::integral_constant{}); + return ave_time = + Run(ck_tile::integral_constant{}); } - - return ave_time; } #include "run_grouped_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp index 4ce55e8e72..b9d6a4a1bc 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp @@ -76,99 +76,95 @@ float grouped_gemm(const std::vector& gemm_descs, float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - - return ave_time; - }; + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(gemm_descs[0].k_batch == 1) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } else { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } #include "run_grouped_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index 409bb173a1..64c9dda64a 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -109,23 +109,19 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); }; - Run(ck_tile::integral_constant{}); - - return ave_time; + return ave_time = Run(ck_tile::integral_constant{}); } #include "quant_run_grouped_gemm_example.inc" diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 280da8d333..3273fac674 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -167,38 +167,38 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c hipGetErrorString(hipMemsetAsync( args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - ave_time = ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } - return ave_time; }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } else { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } template