diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp index 62744d9895..c312a53c2a 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp @@ -46,14 +46,6 @@ struct SplitKTwoStageInvoker GemmConfig::TileParitionerGroupNum, GemmConfig::TileParitionerM01>; - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + constexpr auto scheduler = GemmConfig::Scheduler; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using WorkspaceType = ck_tile::remove_cvref_t; - const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - - using WorkspaceType = ck_tile::remove_cvref_t; - using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem( @@ -244,21 +217,15 @@ struct SplitKTwoStageInvoker ck_tile::make_tuple(args.N, 1), // Output Stride input_tensors, static_cast(c_ptr))); - - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - return Run(has_hot_loop_, tail_number_, MemoryOpSet{}); - } - else - { - return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); - } - }; - - return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(args.k_batch == 1) + { + return Run(MemoryOpSet{}); + } + else + { + return Run(MemoryOpAtomicAdd{}); + } } }; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index 74edddb6c9..abad4ab5c4 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -133,14 +133,6 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& GemmConfig::TileParitionerGroupNum, GemmConfig::TileParitionerM01>; - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; - - const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - // Create base GEMM arguments pointing to workspace instead of final output // The workspace will store partial results from each K-split ck_tile::GemmHostArgs base_args(args.a_ptr, @@ -179,23 +158,18 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& args.stride_A, args.stride_B, args.stride_E); + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&]() { + // use SET operation since each K-split writes to separate memory + constexpr auto memory_operation = ck_tile::memory_operation_enum::set; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + scheduler>; using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; @@ -276,29 +250,20 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& hipGetErrorString(hipMemsetAsync( args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - return ave_time = ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - return ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - // For workspace mode, always use SET operation since each K-split writes to separate memory - return Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - }; - - return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return Run(); } /** diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp index 07f449f34b..b394598110 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp @@ -33,14 +33,6 @@ struct WeightPreshuffleInvoker GemmConfig::TileParitionerGroupNum, GemmConfig::TileParitionerM01>; - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + constexpr auto scheduler = GemmConfig::Scheduler; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem{}); - } - else - { - throw std::runtime_error("split-k is not supported yet!"); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("split-k is not supported yet!"); + } } }; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 30cb3d3476..c4f100b36b 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -63,14 +63,17 @@ void permute_tensor_b(Tensor& tensor) GemmConfig::TransposeC, GemmConfig::UseStructuredSparsity>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index b9b05a8e86..0fcf9680bc 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -34,14 +34,6 @@ struct UniversalInvoker GemmConfig::TileParitionerGroupNum, GemmConfig::TileParitionerM01>; - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + constexpr auto scheduler = GemmConfig::Scheduler; - const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - return Run(has_hot_loop_, tail_number_, MemoryOpSet{}); - } - else - { - return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); - } - }; - - return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(args.k_batch == 1) + { + return Run(MemoryOpSet{}); + } + else + { + return Run(MemoryOpAtomicAdd{}); + } } }; diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index 6838e899e6..c7e37bc8a7 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -59,7 +59,6 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + constexpr auto scheduler = GemmConfig::Scheduler; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; - float ave_time{0}; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + const dim3 blocks = Kernel::BlockSize(); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - - using Kernel = ck_tile::BatchedGemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << GemmPipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; - - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } } #include "run_batched_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 531e437006..3ff3f2f10e 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -42,12 +42,6 @@ float grouped_gemm(const std::vector& gemm_descs, GemmConfig::TileParitionerGroupNum, GemmConfig::TileParitionerM01>; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits& gemm_descs, BLayout, CLayout, GemmConfig::TransposeC>; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + constexpr auto scheduler = GemmConfig::Scheduler; - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - float ave_time{0}; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; - - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } - - return ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(gemm_descs[0].k_batch == 1) + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); }; - return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(gemm_descs[0].k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } } template & gemm_d GemmConfig::TileParitionerGroupNum, GemmConfig::TileParitionerM01>; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits& gemm_d BLayout, ELayout, GemmConfig::TransposeC>; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + constexpr auto scheduler = GemmConfig::Scheduler; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem& gemm_d << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - - return ave_time; + return ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(gemm_descs[0].k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + if(gemm_descs[0].k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } } template & gemm_descs, GemmConfig::TileParitionerGroupNum, GemmConfig::TileParitionerM01>; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits& gemm_descs, GemmConfig::Persistent, GemmConfig::NumWaveGroups, GemmConfig::Preshuffle>; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + constexpr auto scheduler = GemmConfig::Scheduler; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = - // if preshuffle == true then num_loop is recalculated for each group in the kernel code - TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; - float ave_time{0}; - - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; - - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } - - return ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(gemm_descs[0].k_batch == 1) + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); }; - return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(gemm_descs[0].k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } } template ; - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + constexpr auto scheduler = GemmConfig::Scheduler; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - float ave_time{0}; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using Kernel = ck_tile::GemmKernelMultiD; + auto kargs = Kernel::MakeKernelArgs(args); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - - using Kernel = ck_tile::GemmKernelMultiD; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " - << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " - << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; - - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y + << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } } #include "run_gemm_multi_d_fp16_example.inc" diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp index 7638b92002..d2663b033c 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp @@ -57,43 +57,9 @@ struct GroupedConvolutionBackwardDataInvoker GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; + constexpr auto scheduler = ConvConfig::Scheduler; - using GemmPipelineProblem = ck_tile::GemmPipelineProblem< - OutDataType, - WeiDataType, - AccDataType, - GemmShape, - typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdData< - ConvConfig::NumWaveGroups>, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - InDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; - - using BaseGemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template UniversalGemmPipeline; - - const ck_tile::index_t gemm_k = - args.K_ * std::accumulate(args.filter_spatial_lengths_.begin(), - args.filter_spatial_lengths_.end(), - 1, - std::multiplies()); - - const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ConvConfig::Scheduler; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< @@ -103,8 +69,6 @@ struct GroupedConvolutionBackwardDataInvoker GemmShape, GemmUniversalTraits, scheduler, - has_hot_loop_v, - tail_number_v, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, InDataType, @@ -170,26 +134,19 @@ struct GroupedConvolutionBackwardDataInvoker kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); }; - ave_time = ck_tile::launch_kernel_time_mask( + return ck_tile::launch_kernel_time_mask( s, preprocess, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, tail_number_, MemoryOpSet{}); - } - else - { - Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + if(args.k_batch == 1) + { + return Run(MemoryOpSet{}); + } + else + { + return Run(MemoryOpAtomicAdd{}); + } } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp index f7171ef9d9..0891e8c20b 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp @@ -57,43 +57,9 @@ struct GroupedConvolutionBackwardWeightInvoker GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; + constexpr auto scheduler = ConvConfig::Scheduler; - using GemmPipelineProblem = ck_tile::GemmPipelineProblem< - OutDataType, - InDataType, - AccDataType, - GemmShape, - typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight< - ConvConfig::NumWaveGroups>, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - WeiDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; - - using BaseGemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template UniversalGemmPipeline; - - const ck_tile::index_t gemm_k = - args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), - args.output_spatial_lengths_.end(), - 1, - std::multiplies()); - - const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ConvConfig::Scheduler; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< @@ -103,8 +69,6 @@ struct GroupedConvolutionBackwardWeightInvoker GemmShape, GemmUniversalTraits, scheduler, - has_hot_loop_v, - tail_number_v, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, WeiDataType, @@ -176,26 +140,19 @@ struct GroupedConvolutionBackwardWeightInvoker } }; - ave_time = ck_tile::launch_kernel_time_mask( + return ck_tile::launch_kernel_time_mask( s, preprocess, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, tail_number_, MemoryOpSet{}); - } - else - { - Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + if(args.k_batch == 1) + { + return Run(MemoryOpSet{}); + } + else + { + return Run(MemoryOpAtomicAdd{}); + } } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp index 5d78bc4739..50c0ce4f87 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp @@ -60,42 +60,9 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; - using GemmPipelineProblem = ck_tile::GemmPipelineProblem< - OutDataType, - InDataType, - AccDataType, - GemmShape, - typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight< - ConvConfig::NumWaveGroups>, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - WeiDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + constexpr auto scheduler = ConvConfig::Scheduler; - using BaseGemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template UniversalGemmPipeline; - - const ck_tile::index_t gemm_k = - args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), - args.output_spatial_lengths_.end(), - 1, - std::multiplies()); - - const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ConvConfig::Scheduler; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< @@ -105,8 +72,6 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker GemmShape, GemmUniversalTraits, scheduler, - has_hot_loop_v, - tail_number_v, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, WeiDataType, @@ -209,7 +174,6 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker { std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << GemmPipelineProblem::GetName() << '\n' << "pipeline: " << GemmPipeline::GetName() << '\n' << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z @@ -228,7 +192,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker s.stream_id_)); }; - ave_time = ck_tile::launch_kernel_time_mask( + return ck_tile::launch_kernel_time_mask( s, preprocess, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), @@ -242,22 +206,15 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker ck_tile::make_tuple(shape[1], 1), // Output Stride input_tensors, static_cast(c_ptr))); - - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, tail_number_, MemoryOpSet{}); - } - else - { - Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + if(args.k_batch == 1) + { + return Run(MemoryOpSet{}); + } + else + { + return Run(MemoryOpAtomicAdd{}); + } } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp index 3e1f4c6268..82541bb593 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -65,148 +65,96 @@ struct GroupedConvolutionForwardInvoker GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; - - using GemmPipelineProblem = ck_tile::GemmPipelineProblem< - InDataType, - WeiDataType, - AccDataType, - GemmShape, - typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsFwd< - ConvConfig::NumWaveGroups>, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - OutDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; - - using BaseGemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template UniversalGemmPipeline; - - const ck_tile::index_t gemm_k = - args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(), - args.filter_spatial_lengths_.end(), - 1, - std::multiplies()); - - // Split-K parameters - const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; + constexpr auto scheduler = ConvConfig::Scheduler; // ===================================================================== // Regular Convolution: Simple, no split-image // ===================================================================== - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ConvConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - InDataType, - WeiDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - has_hot_loop_v, - tail_number_v, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - OutDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + OutDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using Kernel = ck_tile::GroupedConvolutionForwardKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::GroupedConvolutionForwardKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; - }; - - // ===================================================================== - // Split-K lambda - // ===================================================================== - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - Run.template operator()(has_hot_loop_, tail_number_, MemoryOpSet{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); } - else + + if(s.log_level_ > 0) { - Run.template operator()(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; } + + return ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; // ===================================================================== - // Regular Convolution Example: ALWAYS uses regular path (Kernel) + // Split-K dispatch // ===================================================================== - // This example demonstrates regular convolution without split-image. - // For large images that don't fit in memory, use - // grouped_convolution_forward_split_image.cpp - - // Launch kernel using regular path (no split-image) - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + if(args.k_batch == 1) + { + return Run(MemoryOpSet{}); + } + else + { + return Run(MemoryOpAtomicAdd{}); + } } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp index d154d8710b..4261385a84 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp @@ -72,36 +72,6 @@ struct GroupedConvolutionForwardInvoker GroupedConvTraitsTypeDefault::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; - using GemmPipelineProblem = ck_tile::GemmPipelineProblem< - InDataType, - WeiDataType, - AccDataType, - GemmShape, - typename GroupedConvTraitsTypeDefault::template GroupedConvImplicitGemmTraitsFwd< - ConvConfig::NumWaveGroups>, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - OutDataType, - GroupedConvTraitsTypeDefault::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsTypeDefault::VectorSizeA, - GroupedConvTraitsTypeDefault::VectorSizeB>; - - using BaseGemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template UniversalGemmPipeline; - - const ck_tile::index_t gemm_k = - args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(), - args.filter_spatial_lengths_.end(), - 1, - std::multiplies()); - - // Split-K parameters - const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - using TransformType = ck_tile::TransformConvFwdToGemm{}); - else - return Run(has_hot_loop_, - tail_number_, - MemoryOpAtomicAdd{}, - ck_tile::bool_constant{}); - }; - return BaseGemmPipeline::TailHandler(RunSplitImage, has_hot_loop, tail_num); + if(args.k_batch == 1) + return Run(MemoryOpSet{}, ck_tile::bool_constant{}); + else + return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant{}); } else { - const auto RunRegular = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - return Run(has_hot_loop_, - tail_number_, - MemoryOpSet{}, - ck_tile::bool_constant{}); - else - return Run(has_hot_loop_, - tail_number_, - MemoryOpAtomicAdd{}, - ck_tile::bool_constant{}); - }; - return BaseGemmPipeline::TailHandler(RunRegular, has_hot_loop, tail_num); + if(args.k_batch == 1) + return Run(MemoryOpSet{}, ck_tile::bool_constant{}); + else + return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant{}); } } }; diff --git a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp index 5ea4299492..acb9126d65 100644 --- a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp +++ b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp @@ -63,8 +63,6 @@ auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_conf using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + constexpr auto scheduler = GemmConfig::Scheduler; - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; - float ave_time{0}; + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using Kernel = ck_tile::GemmKernelMultiABD; + auto kargs = Kernel::MakeKernelArgs(args); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - - using Kernel = ck_tile::GemmKernelMultiABD; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " - << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " - << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; - - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y + << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } } #include "run_gemm_multi_abd_fp16_example.inc" diff --git a/example/ck_tile/41_batched_contraction/batched_contraction.cpp b/example/ck_tile/41_batched_contraction/batched_contraction.cpp index 6536894394..f9f13c6e85 100644 --- a/example/ck_tile/41_batched_contraction/batched_contraction.cpp +++ b/example/ck_tile/41_batched_contraction/batched_contraction.cpp @@ -90,24 +90,9 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs; - using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; - ck_tile::index_t K_total = 1; - for(ck_tile::index_t i = NumDimG + NumDimM; i < NumDimG + NumDimM + NumDimK; ++i) - { - K_total *= args.A_dims[i]; - } - - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_total); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + const auto Run = [&]() { constexpr auto memory_operation = ck_tile::memory_operation_enum::set; // Always set (no atomic_add) @@ -116,9 +101,7 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs; + scheduler>; using GemmPipeline = GEMM_PIPELINE; @@ -166,14 +149,10 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs(Kernel{}, grids, blocks, 0, kargs); - ave_time = ck_tile::launch_kernel(s, kernel); - - return ave_time; + return ck_tile::launch_kernel(s, kernel); }; - BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); - - return ave_time; + return Run(); } #define HANDLE_CASE(G, M, N, K) \ diff --git a/experimental/builder/test/test_bwd_data_instance_traits.cpp b/experimental/builder/test/test_bwd_data_instance_traits.cpp index 6b18095544..80e8ae8d98 100644 --- a/experimental/builder/test/test_bwd_data_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_data_instance_traits.cpp @@ -54,8 +54,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) GemmShape, GemmUniversalTraits, ck_tile::GemmPipelineScheduler::Intrawave /*scheduler*/, - true /*has_hot_loop_v*/, - ck_tile::TailNumber::Full /*tail_number_v*/, ck_tile::element_wise::PassThrough /*AElementwiseOperation*/, ck_tile::element_wise::PassThrough /*BElementwiseOperation*/, ck_tile::bf16_t /*InDataType*/, diff --git a/experimental/builder/test/test_bwd_weight_instance_traits.cpp b/experimental/builder/test/test_bwd_weight_instance_traits.cpp index 3ecd06e33d..9b3cd169bb 100644 --- a/experimental/builder/test/test_bwd_weight_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_weight_instance_traits.cpp @@ -156,8 +156,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) GemmShape, GemmUniversalTraits, ck_tile::GemmPipelineScheduler::Intrawave /*scheduler*/, - true /*has_hot_loop_v*/, - ck_tile::TailNumber::Full /*tail_number_v*/, ck_tile::element_wise::PassThrough /*AElementwiseOperation*/, ck_tile::element_wise::PassThrough /*BElementwiseOperation*/, ck_tile::bf16_t /*WeiDataType*/, diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index 9da707bfec..6a8f1f14e3 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -767,8 +767,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) GemmShape, GemmUniversalTraits, ck_tile::GemmPipelineScheduler::Intrawave /*scheduler*/, - true /*has_hot_loop_v*/, - ck_tile::TailNumber::Full /*tail_number_v*/, ck_tile::element_wise::PassThrough /*AElementwiseOperation*/, ck_tile::element_wise::PassThrough /*BElementwiseOperation*/, ck_tile::bf16_t /*OutDataType*/, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index d27f937435..0b2cdde05e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -19,12 +19,12 @@ struct BaseGemmPipelineAgBgCrCompAsync static constexpr index_t PrefillStages = 1; static constexpr index_t GlobalBufferNum = 1; - CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; } - CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) { if(num_loop == 1) { @@ -158,9 +158,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}; static constexpr auto is_b_load_tr_v = bool_constant{}; @@ -539,14 +537,21 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - num_loop, - p_smem_0, - p_smem_1); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem_0, + p_smem_1); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } public: @@ -557,14 +562,21 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}.template operator()( - a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, - b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, - num_loop, - p_smem_0, - p_smem_1); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + num_loop, + p_smem_0, + p_smem_1); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index f83462391c..d4475e8c60 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -154,10 +154,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr index_t Preshuffle = Problem::Preshuffle; - static constexpr bool HasHotLoop = - Problem::HasHotLoop; // Base::BlockHasHotloop(Problem::num_loop); - static constexpr auto TailNum = - Problem::TailNum; // Base::GetBlockLoopTailNum(Problem::num_loop); static constexpr auto Scheduler = Problem::Scheduler; static constexpr auto is_a_load_tr_v = bool_constant{}; @@ -641,13 +637,20 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 index_t num_loop, void* p_smem) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - num_loop, - p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } /** @@ -700,13 +703,15 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 index_t num_loop, void* p_smem) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - [](auto& e, const ADataType& a) { e = a; }, - b_dram_block_window_tmp, - [](auto& e, const BDataType& b) { e = b; }, - num_loop, - p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + return operator()(a_dram_block_window_tmp, + b_dram_block_window_tmp, + num_loop, + has_hot_loop, + tail_number, + p_smem); } template static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr index_t Preshuffle = Problem::Preshuffle; - static constexpr bool HasHotLoop = Problem::HasHotLoop; - static constexpr auto TailNum = Problem::TailNum; - static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto Scheduler = Problem::Scheduler; static constexpr auto is_a_load_tr_v = bool_constant{}; static constexpr auto is_b_load_tr_v = bool_constant{}; @@ -685,14 +683,21 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 void* p_smem_0, void* p_smem_1) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - num_loop, - p_smem_0, - p_smem_1); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem_0, + p_smem_1); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template void* __restrict__ p_smem_0, void* __restrict__ p_smem_1) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - [](auto& e, const ADataType& a) { e = a; }, - b_dram_block_window_tmp, - [](auto& e, const BDataType& b) { e = b; }, - num_loop, - p_smem_0, - p_smem_1); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](auto& e, const ADataType& a) { e = a; }, + b_dram_block_window_tmp, + [](auto& e, const BDataType& b) { e = b; }, + num_loop, + p_smem_0, + p_smem_1); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; static constexpr index_t Preshuffle = Problem::Preshuffle; - static constexpr bool HasHotLoop = Problem::HasHotLoop; - static constexpr auto TailNum = Problem::TailNum; - static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto Scheduler = Problem::Scheduler; static constexpr index_t NumWarps = BlockGemmShape::NumWarps; static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{}); @@ -404,13 +402,20 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 index_t num_loop, void* p_smem_0) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - num_loop, - p_smem_0); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem_0); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template const index_t num_loop, void* __restrict__ p_smem_0) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - [](auto& e, const ADataType& a) { e = a; }, - b_dram_block_window_tmp, - [](auto& e, const BDataType& b) { e = b; }, - num_loop, - p_smem_0); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](auto& e, const ADataType& a) { e = a; }, + b_dram_block_window_tmp, + [](auto& e, const BDataType& b) { e = b; }, + num_loop, + p_smem_0); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template PrefetchStages; } - CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) { if(num_loop % HotloopUnroll == 1) { @@ -153,9 +153,7 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6 static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; static constexpr index_t Preshuffle = Problem::Preshuffle; - static constexpr bool HasHotLoop = Problem::HasHotLoop; - static constexpr auto TailNum = Problem::TailNum; - static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto Scheduler = Problem::Scheduler; static constexpr auto is_a_load_tr_v = bool_constant{}; static constexpr auto is_b_load_tr_v = bool_constant{}; @@ -173,11 +171,9 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6 return concat('_', "pipeline_AgBgCrCompV6", BlockSize, concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), concat('x', kPadM, kPadN, kPadK), - concat('x', TailNum), concat('_', KRepeat), concat('_', DoubleSmemBuffer), - concat('_', Preshuffle), - concat('_', HasHotLoop)); + concat('_', Preshuffle)); // clang-format on } @@ -725,13 +721,20 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6 index_t num_loop, void* __restrict__ p_smem) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - num_loop, - p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template const index_t num_loop, void* __restrict__ p_smem) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - [](auto& e, const ADataType& a) { e = a; }, - b_dram_block_window_tmp, - [](auto& e, const BDataType& b) { e = b; }, - num_loop, - p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](auto& e, const ADataType& a) { e = a; }, + b_dram_block_window_tmp, + [](auto& e, const BDataType& b) { e = b; }, + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr index_t Preshuffle = Problem::Preshuffle; - // Where is the right place for HasHotLoop and TailNum ??? - static constexpr bool HasHotLoop = Problem::HasHotLoop; - static constexpr auto TailNum = Problem::TailNum; - static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto Scheduler = Problem::Scheduler; static constexpr auto is_a_load_tr_v = bool_constant{}; static constexpr auto is_b_load_tr_v = bool_constant{}; @@ -887,13 +884,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem index_t num_loop, void* p_smem) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - a_element_func, - b_dram_block_window_tmp, - b_element_func, - num_loop, - p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template index_t num_loop, void* p_smem) const { - return PipelineImpl{}.template operator()( - a_dram_block_window_tmp, - [](auto& e, const ADataType& a) { e = a; }, - b_dram_block_window_tmp, - [](auto& e, const ADataType& a) { e = a; }, - num_loop, - p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](auto& e, const ADataType& a) { e = a; }, + b_dram_block_window_tmp, + [](auto& e, const BDataType& b) { e = b; }, + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } template = DsReadPreload) ? DsReadPreload : MIterPerWarp * KIterPerWarp; - static constexpr auto TailNum = Problem::TailNum; #ifdef __gfx942__ static constexpr index_t mfma_per_wg = 2; @@ -1042,13 +1041,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 void* p_smem_ping, void* p_smem_pong) const { - return operator()( - a_dram_block_window_tmp[number<0>{}], - [](const ADataType& a) { return a; }, - b_flat_dram_block_window_tmp[number<0>{}], - num_loop, - p_smem_ping, - p_smem_pong); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto bool_val, auto tail_num_) { + (void)bool_val; // Suppress unused parameter warning + constexpr auto tail_num = tail_num_.value; + constexpr auto PassThrough = [](const ADataType& a) { return a; }; + return operator()(a_dram_block_window_tmp[number<0>{}], + PassThrough, + b_flat_dram_block_window_tmp[number<0>{}], + num_loop, + p_smem_ping, + p_smem_pong); + }; + return Base::TailHandler(RunPipeline, true, tail_number); } // called from general gemm kernel @@ -1063,13 +1069,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 void* p_smem_ping, void* p_smem_pong) const { - return operator()( - a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, - b_flat_dram_block_window_tmp, - num_loop, - p_smem_ping, - p_smem_pong); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + + const auto RunPipeline = [&](auto bool_val, auto tail_num_) { + (void)bool_val; // Suppress unused parameter warning + constexpr auto tail_num = tail_num_.value; + constexpr auto PassThrough = [](const ADataType& a) { return a; }; + return operator()(a_dram_block_window_tmp, + PassThrough, + b_flat_dram_block_window_tmp, + num_loop, + p_smem_ping, + p_smem_pong); + }; + return Base::TailHandler(RunPipeline, true, tail_number); } // called from grouped gemm kernel diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 3c344259bb..77eb416532 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -81,7 +81,6 @@ class TestCkTileBatchedGemm : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::BatchedGemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -154,36 +135,26 @@ class TestCkTileBatchedGemm : public ::testing::Test { std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << GemmPipelineProblem::GetName() << '\n' << "pipeline: " << GemmPipeline::GetName() << '\n' << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - ave_time = ck_tile::launch_kernel( + return ck_tile::launch_kernel( s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(args.k_batch == 1) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } } public: diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index c489d3be54..a0c078a1e9 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -159,8 +159,6 @@ class TestCkTileGemmPipeline : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using BaseGemmPipeline = - typename GemmPipelineTypeSelector::base_pipeline; + using GemmPipeline = + typename GemmPipelineTypeSelector::pipeline; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = - typename GemmPipelineTypeSelector::pipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(Kernel{}, grids, blocks, 0, kargs)); }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(args.k_batch == 1) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } } public: diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp index 8234692696..ee045c7f48 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp @@ -134,7 +134,6 @@ class TestCkTileGemmMultiABD : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - float ave_time{0}; + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< ck_tile::DefaultGemm2DEpilogueProblem(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - std::cout << "Run without SplitK" << std::endl; - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - std::cout << "Run using SplitK" << std::endl; - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(args.k_batch == 1) + { + std::cout << "Run without SplitK" << std::endl; + Run(ck_tile::integral_constant{}); + } + else + { + std::cout << "Run using SplitK" << std::endl; + Run(ck_tile::integral_constant{}); + } } public: diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp index 373370b18c..8217f5a3d9 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -150,7 +150,6 @@ class TestCkTileGemmMultiD : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - float ave_time{0}; + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< ck_tile::DefaultGemm2DEpilogueProblem(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - std::cout << "Run without SplitK" << std::endl; - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - std::cout << "Run using SplitK" << std::endl; - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(args.k_batch == 1) + { + std::cout << "Run without SplitK" << std::endl; + Run(ck_tile::integral_constant{}); + } + else + { + std::cout << "Run using SplitK" << std::endl; + Run(ck_tile::integral_constant{}); + } } public: diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 928c72b62d..43a73738d9 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -132,8 +132,6 @@ class TestCkTileGemmPipeline : public ::testing::Test GemmConfig::K_Warp_Tile>>; using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - - using Traits = ck_tile::TileGemmTraits; static constexpr bool StructuredSparsity = false; static constexpr bool NumWaveGroup = 1; @@ -150,37 +148,19 @@ class TestCkTileGemmPipeline : public ::testing::Test NumWaveGroup, preshuffle>; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using BaseGemmPipeline = - typename GemmPipelineTypeSelector::base_pipeline; + using GemmPipeline = + typename GemmPipelineTypeSelector::pipeline; - const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = - typename GemmPipelineTypeSelector::pipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(Kernel{}, grids, blocks, 0, kargs)); }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(args.k_batch == 1) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } } public: diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index a64542aa95..db51a3e8b2 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -91,12 +91,6 @@ class TestCkTileGroupedGemm : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GroupedGemKernelParam::K_Tile; - const ck_tile::index_t K_split = - (gemm_descs[0].K + k_grain - 1) / k_grain * GroupedGemKernelParam::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::GroupedGemmKernel; auto kargs = Kernel::MakeKargs(gemm_descs); EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); @@ -176,7 +151,7 @@ class TestCkTileGroupedGemm : public ::testing::Test << blocks.z << "}" << std::endl; } - ave_time = ck_tile::launch_kernel( + return ck_tile::launch_kernel( s, ck_tile::make_kernel( Kernel{}, @@ -185,29 +160,20 @@ class TestCkTileGroupedGemm : public ::testing::Test 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), gemm_descs.size())); - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(gemm_descs[0].k_batch == 1) - { - std::cout << "Run without SplitK" << std::endl; - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - std::cout << "Run using SplitK" << std::endl; - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(gemm_descs[0].k_batch == 1) + { + std::cout << "Run without SplitK" << std::endl; + Run(ck_tile::integral_constant{}); + } + else + { + std::cout << "Run using SplitK" << std::endl; + Run(ck_tile::integral_constant{}); + } } template diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp index 4397668a5d..b065df6f8a 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp @@ -104,8 +104,6 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; - // for testing purposes, we can hardcode the values here as we what is compatible with // pipeline using GemmUniversalTraits = @@ -121,49 +119,24 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test /*Persistent*/ false, /*NumWaveGroups*/ 1, /*Preshuffle*/ false>; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = std::conditional_t< + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = std::conditional_t< Config::Pipeline_ == (PipelineType::Memory), - ck_tile::BaseGemmPipelineAgBgCrMem, + ck_tile::GemmPipelineAgBgCrMem, std::conditional_t, - ck_tile::BaseGemmPipelineAgBgCrCompV4>>; + ck_tile::GemmPipelineAgBgCrCompV3, + ck_tile::GemmPipelineAgBgCrCompV4>>; - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * Config::K_Tile_; - const ck_tile::index_t K_split = - (gemm_descs[0].K + k_grain - 1) / k_grain * Config::K_Tile_; - const ck_tile::index_t num_loop = - ck_tile::GemmSpatiallyLocalTilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = std::conditional_t< - Config::Pipeline_ == (PipelineType::Memory), - ck_tile::GemmPipelineAgBgCrMem, - std::conditional_t, - ck_tile::GemmPipelineAgBgCrCompV4>>; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem( Kernel{}, @@ -211,25 +184,18 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), gemm_descs.size())); - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(gemm_descs[0].k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - // EXPECT TO FAIL because splitk is not supported - EXPECT_FALSE(true); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(gemm_descs[0].k_batch == 1) + { + Run(ck_tile::integral_constant{}); + } + else + { + // EXPECT TO FAIL because splitk is not supported + EXPECT_FALSE(true); + } } void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index c322aac575..0eb388082b 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -123,8 +123,6 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; - // for testing purposes, we can hardcode the values here as we what is compatible with // pipeline using GemmUniversalTraits = @@ -140,58 +138,37 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test /*Persistent*/ false, /*NumWaveGroups*/ 1, /*Preshuffle*/ true>; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = + ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile; - const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = - ck_tile::GemmSpatiallyLocalTilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = - ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::GroupedGemmKernel; auto kargs = Kernel::MakeKargs(gemm_descs); EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); @@ -204,7 +181,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test hipMemcpyHostToDevice, s.stream_id_)); - ave_time = ck_tile::launch_kernel( + return ck_tile::launch_kernel( s, ck_tile::make_kernel( Kernel{}, @@ -213,25 +190,18 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), gemm_descs.size())); - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(gemm_descs[0].k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - // EXPECT TO FAIL because splitk is not supported - EXPECT_FALSE(true); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(gemm_descs[0].k_batch == 1) + { + Run(ck_tile::integral_constant{}); + } + else + { + // EXPECT TO FAIL because splitk is not supported + EXPECT_FALSE(true); + } } private: @@ -247,8 +217,6 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; - // Enable persistent mode for preshuffle using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; - - const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile; - const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = - ck_tile::GemmSpatiallyLocalTilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = + ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; + const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = - ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::GroupedGemmKernel; auto kargs = Kernel::MakeKargs(gemm_descs); EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); @@ -327,7 +273,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test hipMemcpyHostToDevice, s.stream_id_)); - ave_time = ck_tile::launch_kernel( + return ck_tile::launch_kernel( s, ck_tile::make_kernel( Kernel{}, @@ -336,25 +282,18 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), gemm_descs.size())); - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(gemm_descs[0].k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - // EXPECT TO FAIL because splitk is not supported - EXPECT_FALSE(true); - } - }; - - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + if(gemm_descs[0].k_batch == 1) + { + Run(ck_tile::integral_constant{}); + } + else + { + // EXPECT TO FAIL because splitk is not supported + EXPECT_FALSE(true); + } } public: diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index d450f20105..65fede6a5f 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -337,13 +337,6 @@ class GemmKernelBuilder: "compv4": "ck_tile::GemmPipelineAgBgCrCompV4", } - # Map pipeline names to base pipeline for hot loop detection - base_pipeline_map = { - "mem": "ck_tile::BaseGemmPipelineAgBgCrMem", - "compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3", - "compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4", - } - # Map scheduler names to the correct enum values scheduler_type_map = { "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", @@ -423,33 +416,10 @@ struct SelectedKernel {{ // Tile partitioner using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; - - // Traits - using Traits = ck_tile::TileGemmTraits; - - // Pipeline problem - using GemmPipelineProblem = ck_tile::GemmPipelineProblem< - ADataType, - BDataType, - AccDataType, - TileShape, - Traits>; - - // Base pipeline for hot loop detection - using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}; static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ - const ck_tile::index_t k_grain = args.k_batch * TileK; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{{0}}; - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{ - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; + const auto Run = [&](const auto memory_operation_) {{ constexpr auto scheduler = {scheduler_type_map.get(scheduler)}; [[maybe_unused]] constexpr auto memory_operation = memory_operation_.value; @@ -462,9 +432,7 @@ struct SelectedKernel {{ ALayout, BLayout, CLayout, TransposeC, UseStructuredSparsity, UsePersistentKernel, NumWaveGroups, Preshuffle>, - scheduler, - has_hot_loop_v, - tail_number_v>; + scheduler>; using GemmPipeline = {pipeline_impl_map.get(pipeline)}; @@ -542,28 +510,23 @@ struct SelectedKernel {{ // Launch kernel constexpr int kBlockPerCu = {k_block_per_cu}; - ave_time = ck_tile::launch_kernel( + float ave_time = ck_tile::launch_kernel( stream, ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); return ave_time; }}; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{ - if(args.k_batch == 1) {{ - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{{}}); - }} else {{ - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{{}}); - }} - }}; + float ave_time = 0.f; + + if(args.k_batch == 1) {{ + ave_time = Run(ck_tile::integral_constant{{}}); + }} else {{ + ave_time = Run(ck_tile::integral_constant{{}}); + }} - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); return ave_time; }} }};