// NOTE(review): this copy of the file is collapsed onto three physical lines, and several
// angle-bracket lists appear to be missing (the #include targets, the template parameter list of
// grouped_gemm_tileloop, and many template argument lists such as those of TileGemmShape,
// TileGemmTraits and integral_constant). The `//` comments on the first two physical lines now
// swallow the remainder of those lines. Restore line breaks and the missing lists from the
// upstream file before building; the notes below describe only what the visible text shows.
//
// grouped_gemm_tileloop(s, num_groups, kargs_ptr, splitk)
//   Assembles a ck_tile::GroupedGemmKernel type from compile-time tile settings and launches it
//   once through ck_tile::launch_kernel.
//
//   Parameters:
//     s          - ck_tile::stream_config; passed to Kernel::MaxOccupancyGridSize and to
//                  ck_tile::launch_kernel, and its log_level_ member gates the launch log line.
//     num_groups - forwarded unchanged as the last kernel argument.
//     kargs_ptr  - forwarded to the kernel after ck_tile::cast_pointer_to_constant_address_space.
//                  Presumably points at the per-group kernel argument array - TODO confirm
//                  against the caller in run_grouped_gemm_example.inc.
//     splitk     - selects which ck_tile::integral_constant is handed to Run. The constants'
//                  template arguments are not visible in this copy; presumably they choose the
//                  epilogue memory operation (plain store vs. accumulate) - TODO confirm.
//   Returns:
//     ave_time, which starts at 0 and is overwritten with the value returned by
//     ck_tile::launch_kernel. Units are not stated here (presumably milliseconds - confirm).
//
//   Tile settings are picked by preprocessor comparison of CK_TILE_PIPELINE_DEFAULT:
//     CK_TILE_PIPELINE_MEMORY     : tile 128x32x64,  warps 4x1x1, warp tile 32x32x8,  DoubleSmemBuffer = false
//     CK_TILE_PIPELINE_COMPUTE_V3 : tile 256x256x64, warps 2x2x1, warp tile 32x32x16, DoubleSmemBuffer = false
//     CK_TILE_PIPELINE_COMPUTE_V4 : tile 256x256x32, warps 2x2x1, warp tile 32x32x16, DoubleSmemBuffer = true
//   The MEMORY case is its own #if/#endif and the two COMPUTE cases form a separate #if/#elif
//   with no #else, so when CK_TILE_PIPELINE_DEFAULT matches none of them the tile constants are
//   left undeclared and the later uses will not compile.
//   kPadM/kPadN/kPadK are all false and kBlockPerCu is 1.
//
//   The Run lambda captures by reference, derives the pipeline, CShuffleEpilogue and kernel types,
//   takes the block dimensions from Kernel::BlockSize() and the grid from
//   Kernel::MaxOccupancyGridSize(s), prints the launch configuration when s.log_level_ > 0, and
//   stores the launch result into the captured ave_time (it also returns it; the call sites
//   ignore that return value).
//
// main
//   Returns the logical negation of run_grouped_gemm_example(argc, argv), which comes from the
//   included "run_grouped_gemm_example.inc". Persistent is declared as true just before main;
//   how it reaches run_grouped_gemm_example is not visible in this copy.
// SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include #include #include #include #include #include "ck_tile/core.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/host.hpp" #include "grouped_gemm.hpp" template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr, bool splitk) { #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) // Memory friendly for Interwave scheduler constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t N_Tile = 32; constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 4; constexpr ck_tile::index_t N_Warp = 1; constexpr ck_tile::index_t K_Warp = 1; constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr bool DoubleSmemBuffer = false; #endif #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) // Compute friendly for Intrawave scheduler constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 16; constexpr bool DoubleSmemBuffer = false; #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) // Compute friendly for Intrawave scheduler // Using the ping pong reader in the lds level constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr 
ck_tile::index_t K_Warp_Tile = 16; constexpr bool DoubleSmemBuffer = true; #endif constexpr bool kPadM = false; constexpr bool kPadN = false; constexpr bool kPadK = false; constexpr int kBlockPerCu = 1; constexpr ck_tile::index_t TileParitionerGroupNum = 8; constexpr ck_tile::index_t TileParitionerM01 = 4; using GemmShape = ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; float ave_time{0}; const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; constexpr auto memory_operation = memory_operation_.value; // We create the GEMM pipeline without specifying hotloop or tailnumber. // These are automatically run inside the kernel based on the given input data. using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; using GemmPipeline = GEMM_PIPELINE; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, N_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, UniversalGemmProblem::TransposeC, memory_operation>>; using Kernel = ck_tile::GroupedGemmKernel; constexpr dim3 blocks = Kernel::BlockSize(); const dim3 grids = Kernel::MaxOccupancyGridSize(s); if(s.log_level_ > 0) { std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } ave_time = ck_tile::launch_kernel(s, ck_tile::make_kernel( Kernel{}, grids, blocks, 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), num_groups)); return 
ave_time; }; if(!splitk) { Run(ck_tile::integral_constant{}); } else { Run(ck_tile::integral_constant{}); } return ave_time; } #include "run_grouped_gemm_example.inc" constexpr bool Persistent = true; int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }