// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

// NOTE(review): this translation unit appears to have been mangled by
// markup/HTML stripping — every "<...>" span (the #include targets below and
// all template parameter/argument lists) has been deleted. All surviving
// tokens are kept byte-identical here; the missing spans must be restored
// from the upstream ck_tile grouped-GEMM example before this file can compile.

#include
#include
#include
#include
#include
#include
#include
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"

// Runs a non-persistent grouped GEMM: a single kernel launch covers all
// groups, after the per-group kernel arguments are staged host-to-device.
//
// NOTE(review): the template parameter list was stripped (presumably the
// GemmConfig plus layout/data-type typenames used throughout the body —
// confirm against upstream).
//
// @param gemm_descs  host-side descriptors, one per GEMM group (the vector's
//                    element type was stripped from its template argument)
// @param s           stream/logging configuration used for the launch
// @param kargs_ptr   device buffer that receives the packed kernel arguments
// @return            time reported by ck_tile::launch_kernel
template
float grouped_gemm(const std::vector& gemm_descs,
                   const ck_tile::stream_config& s,
                   void* kargs_ptr)
{
    // Block/tile shape and work distribution (the sequence arguments that
    // carried the tile sizes were stripped).
    using GemmShape = ck_tile::TileGemmShape<
        ck_tile::sequence,
        ck_tile::sequence,
        ck_tile:: sequence>;
    using TilePartitioner      = ck_tile::GemmSpatiallyLocalTilePartitioner;
    using GemmUniversalTraits  = ck_tile::TileGemmUniversalTraits;
    constexpr auto scheduler   = GemmConfig::Scheduler;

    using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem;
    // Concrete pipeline implementation is selected via the configured
    // pipeline tag (stripped from PipelineTypeTraits' argument list).
    using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline<
        UniversalGemmProblem>;

    // Instantiates epilogue + kernel for one destination memory operation and
    // launches it. NOTE(review): the integral_constant arguments below were
    // stripped; upstream presumably selects plain store vs. atomic-add for
    // the split-K case — confirm before restoring.
    const auto Run = [&](const auto memory_operation_) {
        constexpr auto memory_operation = memory_operation_.value;
        using GemmEpilogue = ck_tile::CShuffleEpilogue<
            ck_tile::CShuffleEpilogueProblem>;
        using Kernel = ck_tile::GroupedGemmKernel;

        // Build and validate the packed per-group arguments on the host.
        auto kargs = Kernel::MakeKargs(gemm_descs);
        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Kernel arguments not supported!");
        }

        const dim3 blocks = Kernel::BlockSize();
        const dim3 grids  = Kernel::GridSize(gemm_descs);

        // Stage the argument block into the caller-provided device buffer on
        // the same stream the kernel will run on (no extra synchronization).
        HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
                                            kargs.data(),
                                            get_workspace_size(gemm_descs),
                                            hipMemcpyHostToDevice,
                                            s.stream_id_));

        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
                      << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }

        return ck_tile::launch_kernel(
            s,
            ck_tile::make_kernel(Kernel{},
                                 grids,
                                 blocks,
                                 0,
                                 ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
                                 gemm_descs.size()));
    };

    // k_batch == 1 means no K-split accumulation is needed; otherwise the
    // alternative memory operation is used (arguments stripped, see above).
    if(gemm_descs[0].k_batch == 1)
    {
        return Run(ck_tile::integral_constant{});
    }
    else
    {
        return Run(ck_tile::integral_constant{});
    }
}

// Persistent ("tile-loop") variant of the grouped GEMM: the grid is sized by
// occupancy rather than by problem size, and workgroups loop over tiles of
// all groups. Unlike grouped_gemm() above, the kernel arguments are expected
// to already reside in the device buffer kargs_ptr — no host-to-device copy
// is performed here.
//
// NOTE(review): template parameter list stripped (presumably GemmConfig plus
// layout/data-type typenames — AccDataType, CDataType, CLayout etc. are used
// below without a visible declaration).
//
// @param s           stream/logging configuration used for the launch
// @param num_groups  number of GEMM groups described by kargs_ptr
// @param kargs_ptr   device buffer already holding the packed kernel arguments
// @param splitk      selects the split-K memory operation for the epilogue
// @return            time reported by ck_tile::launch_kernel
template
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
                            const ck_tile::index_t num_groups,
                            void* kargs_ptr,
                            bool splitk)
{
    // Block/tile shape (sequence arguments stripped, as above).
    using GemmShape = ck_tile::TileGemmShape<
        ck_tile::sequence,
        ck_tile::sequence,
        ck_tile:: sequence>;
    using TilePartitioner     = ck_tile::GemmSpatiallyLocalTilePartitioner;
    // Persistent-kernel traits distinguish this path from grouped_gemm().
    using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits;

    float ave_time{0};

    const auto Run = [&](const auto memory_operation_) {
        constexpr auto scheduler        = GemmConfig::Scheduler;
        constexpr auto memory_operation = memory_operation_.value;

        // We create the GEMM pipeline without specifying hotloop or tailnumber.
        // These are automatically run inside the kernel based on the given input data.
        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem;
        using GemmPipeline = typename PipelineTypeTraits<
            GemmConfig::Pipeline>::template GemmPipeline;

        // NOTE(review): the epilogue problem's leading arguments (before
        // AccDataType) were stripped — confirm against upstream.
        using GemmEpilogue = ck_tile::CShuffleEpilogue<
            ck_tile::CShuffleEpilogueProblem,
                                             AccDataType,
                                             CDataType,
                                             ck_tile::tuple<>,
                                             CLayout,
                                             ck_tile::element_wise::PassThrough,
                                             TilePartitioner::MPerBlock,
                                             TilePartitioner::NPerBlock,
                                             GemmConfig::M_Warp,
                                             GemmConfig::N_Warp,
                                             GemmConfig::M_Warp_Tile,
                                             GemmConfig::N_Warp_Tile,
                                             GemmConfig::K_Warp_Tile,
                                             UniversalGemmProblem::TransposeC,
                                             memory_operation>>;
        using Kernel = ck_tile::GroupedGemmKernel;

        const dim3 blocks = Kernel::BlockSize();
        // Grid is sized to saturate the device; each workgroup then iterates
        // over work tiles persistently.
        const dim3 grids  = Kernel::MaxOccupancyGridSize(s);

        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
                      << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }

        return ave_time = ck_tile::launch_kernel(
                   s,
                   ck_tile::make_kernel(Kernel{},
                                        grids,
                                        blocks,
                                        0,
                                        ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
                                        num_groups));
    };

    // Memory-operation selection mirrors grouped_gemm(): plain store unless
    // split-K accumulation is requested (integral_constant args stripped).
    if(!splitk)
    {
        return ave_time = Run(ck_tile::integral_constant{});
    }
    else
    {
        return ave_time = Run(ck_tile::integral_constant{});
    }
}

#include "run_grouped_gemm_example.inc"

// Dispatches the example to the concrete (ALayout, BLayout) instantiation
// selected by the "R"/"C" layout strings; the C tensor is always row-major.
//
// NOTE(review): template parameter list stripped (presumably the precision
// tag consumed by GemmTypeConfig and forwarded to the runner — confirm
// against upstream).
//
// @param a_layout "R" (row-major) or "C" (column-major) for tensor A
// @param b_layout "R" or "C" for tensor B
// @param argc/argv forwarded to run_grouped_gemm_example_with_layouts
// @return          the runner's exit code
// @throws std::runtime_error on an unsupported layout combination
template
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    using Types = GemmTypeConfig;
    // Specific type aliases for easy access
    using ADataType   = typename Types::ADataType;
    using BDataType   = typename Types::BDataType;
    using AccDataType = typename Types::AccDataType;
    using CDataType   = typename Types::CDataType;

    if(a_layout == "R" && b_layout == "C")
    {
        return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
    }
    else if(a_layout == "R" && b_layout == "R")
    {
        return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
    }
    else if(a_layout == "C" && b_layout == "R")
    {
        return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
    }
    else if(a_layout == "C" && b_layout == "C")
    {
        return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
    }
    else
    {
        throw std::runtime_error("Unsupported data layout configuration for A and B tensors!");
    }
}

// NOTE(review): truncated — a further template definition begins here, but
// its parameter list was stripped and its body lies outside this chunk.
template