// NOTE(review): this file is stored with its line breaks collapsed. A compiler
// treats everything after a "//" on a physical line as a comment, which here
// includes all code following the SPDX tag on the next line. The targets of the
// bare "#include" directives and the contents of the template parameter and
// argument lists are also absent, so the code is kept exactly as found and only
// described here.
//
// gemm(args, s): configures and launches one ck_tile GEMM kernel and returns the
// time reported by the launch helper.
//
//  * Derives the tile shape, tile partitioner, traits and base pipeline types
//    from GemmConfig.
//  * K_split = ceil(args.K / (args.k_batch * K_Tile)) * K_Tile; the loop count
//    comes from TilePartitioner::GetLoopNum(K_split), and the hot-loop flag and
//    tail number come from BaseGemmPipeline.
//  * Run(...) builds the kernel arguments and picks the grid:
//    Kernel::MaxOccupancyGridSize(s) when Persistent, otherwise
//    Kernel::GridSize(M, N, k_batch). It throws std::runtime_error when
//    Kernel::IsSupportedArgument(kargs) is false, and prints the kernel, shape,
//    problem, pipeline and launch dimensions when s.log_level_ > 0.
//  * With s.flush_cache_ set, a ck_tile::RotatingMemWrapper is created over the
//    A/B buffers and the kernel is launched through launch_kernel_time_mask
//    with a callback that flushes the instruction cache, advances the rotating
//    buffers and, when k_batch > 1, zeroes args.e_ptr (M * N * sizeof(CDataType)
//    bytes) on s.stream_id_. Presumably the helper runs the callback before
//    each timed launch -- confirm against its definition.
//    Otherwise the kernel goes through ck_tile::launch_kernel directly; that
//    path performs no clearing of args.e_ptr.
//  * RunSplitk(...) calls Run with one of two compile-time memory-operation
//    constants depending on whether args.k_batch == 1 (the constants themselves
//    are not legible here), and BaseGemmPipeline::TailHandler invokes RunSplitk
//    with has_hot_loop / tail_num.
//
// NOTE(review): the status returned by hipMemsetAsync is only passed to
// hipGetErrorString and the resulting string is discarded, so a failed clear is
// not reported.
// SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include #include #include #include #include "ck_tile/host.hpp" #include "gemm_utils.hpp" #include "run_gemm_example.inc" template float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, ck_tile::sequence, ck_tile:: sequence, GemmConfig::PermuteA, GemmConfig::PermuteB>; using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; using BaseGemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template UniversalGemmPipeline; const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; constexpr auto scheduler = GemmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); dim3 grids; if constexpr(Persistent) { grids = Kernel::MaxOccupancyGridSize(s); } else { grids = Kernel::GridSize(args.M, args.N, 
args.k_batch); } constexpr dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } if(s.log_level_ > 0) { std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' << "shape: " << GemmShape::GetName() << '\n' << "problem: " << UniversalGemmProblem::GetName() << '\n' << "pipeline: " << GemmPipeline::GetName() << '\n' << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } if(s.flush_cache_) { std::cout << "Flushing cache..." << std::endl; ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( args.M, args.K, args.stride_A, is_row_major(ALayout{}))); ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( args.K, args.N, args.stride_B, is_row_major(BLayout{}))); auto size_a_buffer = a_m.get_element_space_size_in_bytes(); auto size_b_buffer = b_n.get_element_space_size_in_bytes(); ck_tile::RotatingMemWrapper rotating_mem( kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); rotating_mem.Print(); auto run_flush_cache = [&]() { // flush icache ck_tile::flush_icache(); // rotating mem rotating_mem.Next(); // clear c mem if(args.k_batch > 1) hipGetErrorString(hipMemsetAsync( args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; ave_time = ck_tile::launch_kernel_time_mask( s, run_flush_cache, ck_tile::make_kernel( Kernel{}, grids, blocks, 0, kargs)); } else { ave_time = ck_tile::launch_kernel(s, ck_tile::make_kernel( Kernel{}, grids, blocks, 0, kargs)); } return ave_time; }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { Run(has_hot_loop_, tail_number_, ck_tile::integral_constant{}); } else { Run(has_hot_loop_, tail_number_, ck_tile::integral_constant{}); } }; BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); return ave_time; } 
template int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; auto [result, arg_parser] = create_args(argc, argv); bool preshuffle = GemmConfig::Preshuffle; if(preshuffle && std::is_same_v) { throw std::runtime_error("Preshuffle is not supported for this int4 datatype!"); } if(preshuffle && a_layout != "R" && b_layout != "C") { throw std::runtime_error( "Preshuffle is supported only for A(Row major), B(column major) input matrices!"); } if constexpr(std::is_same_v) { if(a_layout == "R" && b_layout == "C") { return run_gemm_example_with_layouts( argc, argv, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { return run_gemm_example_with_layouts( argc, argv, Col{}, Col{}, Row{}); } else { throw std::runtime_error("Unsupported memory layout for the input matrices when " "BPrecType is ck_tile::pk_int4_t!"); } } else { if(a_layout == "R" && b_layout == "R") { return run_gemm_example_with_layouts( argc, argv, Row{}, Row{}, Row{}); } else if(a_layout == "R" && b_layout == "C") { return run_gemm_example_with_layouts( argc, argv, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "R") { return run_gemm_example_with_layouts( argc, argv, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { return run_gemm_example_with_layouts( argc, argv, Col{}, Col{}, Row{}); } else { throw std::runtime_error("Unsupported memory layout for the input matrices!"); } } } template