diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index ce8e37f46d..8eef087bf4 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -5,6 +5,117 @@ #include #include "gemm_utils.hpp" +namespace ck_tile::experimental::builder { +template +struct UniversalFactory +{ + private: + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence, + AlgorithmMetadata::PermuteA::value, + AlgorithmMetadata::PermuteB::value>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + AlgorithmMetadata::Pipeline::value>::template UniversalGemmPipeline; + + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + AlgorithmMetadata::Pipeline::value>::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + public: + using Kernel = ck_tile::GemmKernel; + + CK_TILE_HOST static constexpr auto make_kernel(const ck_tile::GemmHostArgs& args) + { + auto kargs = Kernel::MakeKernelArgs(args); + + // NB: do we really need the stream to be launched here? + const dim3 grids = AlgorithmMetadata::KPersistent::value + ? Kernel::MaxOccupancyGridSize(ck_tile::stream_config{}) + : Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + return ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs); + } +}; +} // namespace ck_tile::experimental::builder + struct UniversalInvoker { template , - ck_tile::sequence, - ck_tile:: - sequence, - GemmConfig::PermuteA, - GemmConfig::PermuteB>; + // const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; + // const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * + // GemmConfig::K_Tile; const ck_tile::index_t num_loop = + // TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = + // BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = + // BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + // const ck_tile::index_t num_loop = 64; + const bool has_hot_loop = true; + const ck_tile::TailNumber tail_num = ck_tile::TailNumber::Full; - using Traits = ck_tile::TileGemmTraits; - - using GemmUniversalTraits = - ck_tile::TileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - - using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; - - const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; const auto kernel_launch_visitor = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; - - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s) - : Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) + struct Algo { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + // can't do `static constexpr` in local structs + using M_Tile = + ck_tile::integral_constant; + using N_Tile = + ck_tile::integral_constant; + using K_Tile = + ck_tile::integral_constant; + using M_Warp = + ck_tile::integral_constant; + using N_Warp = + ck_tile::integral_constant; + using K_Warp = + ck_tile::integral_constant; + using M_Warp_Tile = ck_tile::integral_constant; + using N_Warp_Tile = ck_tile::integral_constant; + using K_Warp_Tile = ck_tile::integral_constant; - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } + using kPadM = + ck_tile::integral_constant; + using kPadN = + ck_tile::integral_constant; + using kPadK = + ck_tile::integral_constant; - // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; - std::function preprocess; + using PermuteA = ck_tile::integral_constant; + using PermuteB = ck_tile::integral_constant; + using UseStructuredSparsity = + ck_tile::integral_constant; + using KPersistent = ck_tile::integral_constant; + using Preshuffle = ck_tile::integral_constant; - auto clear_gemm_output = [&]() { - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + using NumWaveGroups = + ck_tile::integral_constant; + using DoubleSmemBuffer = + ck_tile::integral_constant; + using TransposeC = ck_tile::integral_constant; + + using HasHotLoop = decltype(has_hot_loop_); + using MemoryOperation = decltype(memory_operation_); + using TailNum = decltype(tail_number_); + + using Scheduler = ck_tile::integral_constant; + using TileParitionerGroupNum = + ck_tile::integral_constant; + using TileParitionerM01 = + ck_tile::integral_constant; + using Pipeline = ck_tile::integral_constant; + + using kBlockPerCu = ck_tile::integral_constant; }; - if(s.flush_cache_) + struct Inp { - std::cout << "Flushing cache..." << std::endl; + using InputADataType = ADataType; + using InputBDataType = BDataType; + using InputDsDataType = DsDataType; + using InputCDataType = CDataType; + using InputAccDataType = AccDataType; + using InputALayout = ALayout; + using InputBLayout = BLayout; + using InputDsLayout = DsLayout; + using InputELayout = ELayout; + using InputCDEElementWise = CDEElementWise; + }; - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + // if(s.log_level_ > 0) + // { + // std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + // << "shape: " << GemmShape::GetName() << '\n' + // << "problem: " << UniversalGemmProblem::GetName() << '\n' + // << "pipeline: " << GemmPipeline::GetName() << '\n' + // << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + // << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + // << "}" << std::endl; + // } - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + // Declare rotating_mem_ptr here so it stays in scope until it is needed + // std::unique_ptr> rotating_mem_ptr; + // std::function preprocess; - rotating_mem_ptr = - std::make_unique>( - kargs.as_ptr[0], - kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem_ptr->Print(); + // auto clear_gemm_output = [&]() { + // if(args.k_batch > 1) + // hipGetErrorString(hipMemsetAsync( + // args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + // }; - preprocess = [&]() { - ck_tile::flush_icache(); - rotating_mem_ptr->Next(); - clear_gemm_output(); - }; - } - else - { - preprocess = clear_gemm_output; - } + // if(s.flush_cache_) + // { + // std::cout << "Flushing cache..." << std::endl; - ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + // ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + // args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + // ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + // args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + // auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + // auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + // rotating_mem_ptr = + // std::make_unique>( + // kargs.as_ptr[0], + // kargs.bs_ptr[0], + // s.rotating_count_, + // size_a_buffer, + // size_b_buffer); + // rotating_mem_ptr->Print(); + + // preprocess = [&]() { + // ck_tile::flush_icache(); + // rotating_mem_ptr->Next(); + // clear_gemm_output(); + // }; + // } + // else + // { + // preprocess = clear_gemm_output; + // } + + // ave_time = ck_tile::launch_kernel_time_mask( + // s, + // preprocess, + // ck_tile::make_kernel(Kernel{}, grids, blocks, 0, + // kargs)); + + ave_time = ck_tile::launch_kernel( + s, ck_tile::experimental::builder::UniversalFactory::make_kernel(args)); return ave_time; };