From b77cfe1ad5cc919768bd2d990b7be4da3c2213f2 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Mon, 5 May 2025 18:46:44 +0200 Subject: [PATCH] [CK_TILE] Remove scratch usage from universal gemm (#2001) * moves kbatch condition outside of kernel * add reviewer comments * fixes * fix tests * fixes after review --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> [ROCm/composable_kernel commit: 0bcb804ad079f8b427786cc701675b3c535a180b] --- example/ck_tile/03_gemm/gemm_basic.cpp | 91 ++++++---- example/ck_tile/03_gemm/universal_gemm.cpp | 88 ++++++--- .../ck_tile/16_batched_gemm/batched_gemm.cpp | 171 ++++++++++-------- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 171 ++++++++++-------- .../ops/epilogue/cshuffle_epilogue.hpp | 63 +++---- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 10 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 53 ++---- .../batched_gemm/test_batched_gemm_util.hpp | 34 +++- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 98 ++++++---- .../grouped_gemm/test_grouped_gemm_util.hpp | 34 +++- 10 files changed, 473 insertions(+), 340 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 69051423fb..1edb3da947 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -53,50 +53,67 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - // ToDo: Will add the codegen part to test different pipeline policies in GEMM. - // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - if(!Kernel::IsSupportedArgument(kargs)) + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + if(args.k_batch == 1) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + return Run(ck_tile::integral_constant{}); } - - if(s.log_level_ > 0) + else { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenGemmShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + return Run(ck_tile::integral_constant{}); } - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; } #include "run_gemm_example.inc" diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 2ba16ca89d..e6a2811918 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -61,10 +61,13 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem>; + UniversalGemmProblem::TransposeC, + memory_operation>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -116,23 +120,40 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& return ave_time; }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + if(has_hot_loop) { #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { @@ -146,20 +167,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& // Tail pipeline One to Seven if(tail_num == ck_tile::TailNumber::One) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } if constexpr(BaseGemmPipeline::PrefetchStages > 2) { if(tail_num == ck_tile::TailNumber::Two) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -167,7 +189,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& { if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -175,7 +198,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& { if(tail_num == ck_tile::TailNumber::Four) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -183,7 +207,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& { if(tail_num == ck_tile::TailNumber::Five) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -191,7 +216,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& { if(tail_num == ck_tile::TailNumber::Six) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -199,20 +225,22 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& { if(tail_num == ck_tile::TailNumber::Seven) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } #endif } @@ -220,18 +248,18 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index a0cd18ec74..0219c67305 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -106,61 +106,81 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = GEMM_PIPELINE; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::BatchedGemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + constexpr dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } - - if(s.log_level_ > 0) + else { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << GemmPipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } - - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; }; if(has_hot_loop) @@ -168,18 +188,18 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { @@ -193,20 +213,21 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre // Tail pipeline One to Seven if(tail_num == ck_tile::TailNumber::One) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } if constexpr(BaseGemmPipeline::PrefetchStages > 2) { if(tail_num == ck_tile::TailNumber::Two) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -214,7 +235,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -222,7 +244,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Four) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -230,7 +253,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Five) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -238,7 +262,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Six) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -246,20 +271,22 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Seven) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } #endif } @@ -267,18 +294,18 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } std::ostringstream err; err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but " diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 2a9903362d..9b134ff779 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -114,66 +114,86 @@ float grouped_gemm(const std::vector& gemm_descs, float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = GEMM_PIPELINE; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); + using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); - const dim3 grids = Kernel::GridSize(gemm_descs); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + constexpr dim3 blocks = Kernel::BlockSize(); - ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(s.log_level_ > 0) + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(p_workspace_), + gemm_descs.size())); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } - - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(p_workspace_), - gemm_descs.size())); - return ave_time; }; if(has_hot_loop) @@ -181,18 +201,18 @@ float grouped_gemm(const std::vector& gemm_descs, #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { @@ -206,20 +226,21 @@ float grouped_gemm(const std::vector& gemm_descs, // Tail pipeline One to Seven if(tail_num == ck_tile::TailNumber::One) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } if constexpr(BaseGemmPipeline::PrefetchStages > 2) { if(tail_num == ck_tile::TailNumber::Two) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -227,7 +248,8 @@ float grouped_gemm(const std::vector& gemm_descs, { if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -235,7 +257,8 @@ float grouped_gemm(const std::vector& gemm_descs, { if(tail_num == ck_tile::TailNumber::Four) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -243,7 +266,8 @@ float grouped_gemm(const std::vector& gemm_descs, { if(tail_num == ck_tile::TailNumber::Five) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -251,7 +275,8 @@ float grouped_gemm(const std::vector& gemm_descs, { if(tail_num == ck_tile::TailNumber::Six) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -259,20 +284,22 @@ float grouped_gemm(const std::vector& gemm_descs, { if(tail_num == ck_tile::TailNumber::Seven) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } #endif } diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 225997439e..9b8dde1905 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -22,23 +22,25 @@ template + bool isCTransposed_, + memory_operation_enum MemoryOperation_> struct CShuffleEpilogueProblem { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using CLayout = remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t kMPerBlock = kM_; - static constexpr index_t kNPerBlock = kN_; - static constexpr index_t kMWave = kMWave_; - static constexpr index_t kNWave = kNWave_; - static constexpr index_t kMPerXdl = kMPerXdl_; - static constexpr index_t kNPerXdl = kNPerXdl_; - static constexpr index_t kKPerXdl = kKPerXdl_; - static constexpr index_t isCTransposed = isCTransposed_; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kMPerBlock = kM_; + static constexpr index_t kNPerBlock = kN_; + static constexpr index_t kMWave = kMWave_; + static constexpr index_t kNWave = kNWave_; + static constexpr index_t kMPerXdl = kMPerXdl_; + static constexpr index_t kNPerXdl = kNPerXdl_; + static constexpr index_t kKPerXdl = kKPerXdl_; + static constexpr index_t isCTransposed = isCTransposed_; + static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; }; template @@ -52,18 +54,19 @@ struct CShuffleEpilogue // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, ADataType, BDataType>; - using CLayout = remove_cvref_t; - static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kMPerBlock = Problem::kMPerBlock; - static constexpr index_t kNPerBlock = Problem::kNPerBlock; - static constexpr index_t kMWave = Problem::kMWave; - static constexpr index_t kNWave = Problem::kNWave; - static constexpr index_t kMPerXdl = Problem::kMPerXdl; - static constexpr index_t kNPerXdl = Problem::kNPerXdl; - static constexpr index_t kKPerXdl = Problem::kKPerXdl; - static constexpr index_t isCTransposed = Problem::isCTransposed; - static constexpr index_t kMPerIteration = kMPerXdl * kMWave; - static constexpr index_t kNPerIteration = kNPerXdl * kNWave; + using CLayout = remove_cvref_t; + static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kMPerBlock = Problem::kMPerBlock; + static constexpr index_t kNPerBlock = Problem::kNPerBlock; + static constexpr index_t kMWave = Problem::kMWave; + static constexpr index_t kNWave = Problem::kNWave; + static constexpr index_t kMPerXdl = Problem::kMPerXdl; + static constexpr index_t kNPerXdl = Problem::kNPerXdl; + static constexpr index_t kKPerXdl = Problem::kKPerXdl; + static constexpr index_t isCTransposed = Problem::isCTransposed; + static constexpr index_t kMPerIteration = kMPerXdl * kMWave; + static constexpr index_t kNPerIteration = kNPerXdl * kNWave; using WG = WarpGemmMfmaDispatcher + template CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem) { @@ -179,7 +180,7 @@ struct CShuffleEpilogue const auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); - if constexpr(out_memory_data_op == memory_operation_enum::set) + if constexpr(MemoryOperation == memory_operation_enum::set) { store_tile(out_dram_window, c_out_tensor); } diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index dfb6bfae58..d495c0d950 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -142,15 +142,7 @@ struct BatchedGemmKernel : public GemmKernelRunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); - } - else - { - this->template RunGemm( - a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); - } + this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index bc41f680f2..9c25104cd7 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -608,9 +608,7 @@ struct GemmKernel * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * - * @tparam DstInMemOp Destination memory operation (default: set). */ - template CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr, CDataType* c_ptr, @@ -622,7 +620,8 @@ struct GemmKernel { // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + MakeGemmTensorViews( + a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -640,9 +639,8 @@ struct GemmKernel // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I2); - EpiloguePipeline{} - .template operator()( - c_block_window, c_block_tile, smem_ptr_0); + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, smem_ptr_0); } /** @@ -660,9 +658,7 @@ struct GemmKernel * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * - * @tparam DstInMemOp Destination memory operation (default: set). */ - template CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, const BDataType* b_ptr, CDataType* c_ptr, @@ -675,7 +671,8 @@ struct GemmKernel { // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + MakeGemmTensorViews( + a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -692,9 +689,8 @@ struct GemmKernel // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I2); - EpiloguePipeline{} - .template operator()( - c_block_window, c_block_tile, smem_ptr_0); + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, smem_ptr_0); } CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const @@ -718,7 +714,9 @@ struct GemmKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GetSmemSize()]; - if(kargs.k_batch == 1) + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) { RunGemm2LDS(a_ptr, b_ptr, @@ -730,38 +728,15 @@ struct GemmKernel i_m, i_n); } - else - { - if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(a_ptr, - b_ptr, - c_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } } else { - if(kargs.k_batch == 1) + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) { RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); } - else - { - if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm( - a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); - } - } } } }; diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 0af3ef3b34..4633f23ded 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -81,10 +81,13 @@ class TestCkTileBatchedGemm : public ::testing::Test float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem>; + UniversalGemmProblem::TransposeC, + memory_operation>>; using Kernel = ck_tile::BatchedGemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -138,11 +142,29 @@ class TestCkTileBatchedGemm : public ::testing::Test return ave_time; }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + if(has_hot_loop) { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 1b997ddbce..0329f16416 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -138,9 +138,12 @@ class TestCkTileGemmPipeline : public ::testing::Test const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem>; + UniversalGemmProblem::TransposeC, + memory_operation>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -193,15 +197,32 @@ class TestCkTileGemmPipeline : public ::testing::Test s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + if(has_hot_loop) { if constexpr(PipelineType == GemmPipelineType::CompV3) { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { @@ -219,69 +240,69 @@ class TestCkTileGemmPipeline : public ::testing::Test // Tail pipeline One to Seven if(tail_num == ck_tile::TailNumber::One) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } if constexpr(BaseGemmPipeline::PrefetchStages > 2) { if(tail_num == ck_tile::TailNumber::Two) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } if constexpr(BaseGemmPipeline::PrefetchStages > 3) { if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } if constexpr(BaseGemmPipeline::PrefetchStages > 4) { if(tail_num == ck_tile::TailNumber::Four) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } if constexpr(BaseGemmPipeline::PrefetchStages > 5) { if(tail_num == ck_tile::TailNumber::Five) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } if constexpr(BaseGemmPipeline::PrefetchStages > 6) { if(tail_num == ck_tile::TailNumber::Six) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } if constexpr(BaseGemmPipeline::PrefetchStages > 7) { if(tail_num == ck_tile::TailNumber::Seven) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } } @@ -290,15 +311,15 @@ class TestCkTileGemmPipeline : public ::testing::Test { if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } } @@ -307,7 +328,8 @@ class TestCkTileGemmPipeline : public ::testing::Test // Tail number always Full - #PrefetchStages if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index b125d19762..3dec229643 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -102,10 +102,13 @@ class TestCkTileGroupedGemm : public ::testing::Test float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem>; + UniversalGemmProblem::TransposeC, + memory_operation>>; using Kernel = ck_tile::GroupedGemmKernel; auto kargs = Kernel::MakeKargs(gemm_descs); @@ -164,11 +168,29 @@ class TestCkTileGroupedGemm : public ::testing::Test return ave_time; }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + if(has_hot_loop) { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else