// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
//
// Example driver for a ck_tile grouped GEMM: the kernel configuration is fixed
// at compile time below, grouped_gemm() uploads per-group kernel arguments to a
// device workspace and launches the kernel, and main() delegates to
// run_grouped_gemm_example() from the included .inc file.
//
// NOTE(review): in this copy the header names of the first seven #include
// directives and every template parameter/argument list (the text between
// '<' and '>') appear to have been lost, e.g. `std::vector&`, `template using`,
// `::Kernel::`. The file will not compile as-is; restore these from the
// upstream source rather than guessing them.

#include
#include
#include
#include
#include
#include
#include
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
#include "utils.hpp"

namespace {

// Compile-time tuning parameters for the grouped GEMM kernel.
struct GroupedGemmKernelParam
{
    // Padding is disabled in all three GEMM dimensions. Presumably this means problem
    // sizes must be multiples of the tile sizes below (TODO confirm).
    static const bool kPadM = false;
    static const bool kPadN = false;
    static const bool kPadK = false;
    static const bool kTilePermute = false;

    static const ck_tile::index_t kOutputRank = 2;

    static const int kBlockPerCu = 1;

    // Per-block tile extents along M / N / K.
    static const ck_tile::index_t M_Tile = 128;
    static const ck_tile::index_t N_Tile = 128;
    static const ck_tile::index_t K_Tile = 32;

    // Warp counts along M / N / K (2 x 2 x 1 = 4 warps), per the naming; confirm.
    static const ck_tile::index_t M_Warp = 2;
    static const ck_tile::index_t N_Warp = 2;
    static const ck_tile::index_t K_Warp = 1;

    // Per-warp tile extents along M / N / K.
    static const ck_tile::index_t M_Warp_Tile = 32;
    static const ck_tile::index_t N_Warp_Tile = 32;
    static const ck_tile::index_t K_Warp_Tile = 8;
};

// Tile shape built from three ck_tile::sequence arguments. These are presumably the
// block tile, warp layout and warp tile from GroupedGemmKernelParam, but the arguments
// are missing in this copy (TODO restore).
using CodegenGemmShape = ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>;

using TilePartitioner = ck_tile::GemmTile1DPartitioner;

// Chooses the epilogue at compile time: CShuffleEpilogue when the std::is_same_v
// check holds, otherwise Default2DEpilogue.
// NOTE(review): the compared types are missing here; presumably a C-layout check.
template using GemmEpilogue = std::conditional_t<
    std::is_same_v,
    ck_tile::CShuffleEpilogue>,
    ck_tile::Default2DEpilogue>>;

template using CodegenGemmTraits = ck_tile::TileGemmTraits;

template using CodegenPipelineProblem = ck_tile::GemmPipelineProblem>;

using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;

// GEMM pipeline parameterised by the problem type and CodegenGemmPolicy.
template using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1, CodegenGemmPolicy>;

// Full grouped-GEMM kernel type; GemmEpilogue is its last visible template argument.
template using Kernel = ck_tile::GroupedGemmKernel, GemmEpilogue>;

// NOTE(review): the ';' after this closing brace is an empty declaration and can be dropped.
}; // namespace

// Returns the workspace size that Kernel reports for this set of group descriptors.
// Presumably the byte size that grouped_gemm() requires for p_workspace_ (confirm).
std::size_t GetWorkspaceSize(const std::vector& gemm_descs)
{
    return ::Kernel::GetWorkSpaceSize(gemm_descs);
}

// Launches the grouped GEMM described by gemm_descs on the stream in `s`.
//
// Params:
//   gemm_descs   - one descriptor per GEMM group; its size() is passed to the kernel.
//   s            - stream configuration. stream_id_ is used for the argument upload;
//                  log_level_ > 0 enables launch logging; the whole config is forwarded
//                  to ck_tile::launch_kernel.
//   p_workspace_ - destination of a host-to-device copy. It must hold at least
//                  arguments.size() * sizeof(GemmTransKernelArg) bytes
//                  (see GetWorkspaceSize).
//
// Returns: the float from ck_tile::launch_kernel. Presumably the average kernel time,
// in ms by ck_tile convention (TODO confirm).
template float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* p_workspace_)
{
    using GroupedGemmKernel = ::Kernel;

    // Build the per-group kernel arguments on the host.
    auto arguments = GroupedGemmKernel::MakeKargs(gemm_descs);

    const dim3 grids = GroupedGemmKernel::GridSize(gemm_descs);
    constexpr dim3 blocks = GroupedGemmKernel::BlockSize();

    // Upload the argument array into the caller-provided workspace on the launch stream.
    // The kernel reads it from there rather than receiving it by value.
    ck_tile::hip_check_error(hipMemcpyWithStream(
        p_workspace_,
        arguments.data(),
        arguments.size() * sizeof(typename GroupedGemmKernel::GemmTransKernelArg),
        hipMemcpyHostToDevice,
        s.stream_id_));

    if(s.log_level_ > 0)
    {
        std::cout << "Launching kernel with args:"
                  << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                  << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                  << std::endl;
    }

    // The literal 0 is presumably the dynamic LDS byte count (confirm against
    // ck_tile::make_kernel). The kernel receives the workspace pointer, cast to the
    // constant address space, and the number of groups.
    float ave_time = ck_tile::launch_kernel(
        s,
        ck_tile::make_kernel(
            GroupedGemmKernel{},
            grids,
            blocks,
            0,
            ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
            gemm_descs.size()));

    return ave_time;
}

// Presumably provides run_grouped_gemm_example(), which drives GetWorkspaceSize() and
// grouped_gemm(); it is included after their definitions.
#include "run_grouped_gemm_example.inc"

// Exit status is 0 when run_grouped_gemm_example() returns a truthy value, 1 otherwise.
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }