// SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include #include #include #include #include "ck_tile/host.hpp" #include "flatmm_basic.hpp" template float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. constexpr bool kPadM = false; constexpr bool kPadN = false; constexpr bool kPadK = false; constexpr int kBlockPerCu = 2; // This part comes from the Codegen constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t N_Tile = 128; constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 1; constexpr ck_tile::index_t N_Warp = 4; constexpr ck_tile::index_t K_Warp = 1; constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 16; using CodegenFlatmmShape = ck_tile::TileFlatmmShape, ck_tile::sequence, ck_tile::sequence>; using TilePartitioner = ck_tile::GemmTile1DPartitioner; using CodegenGemmTraits = ck_tile::TileGemmTraits; using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy; using CodegenFlatmmPipeline = ck_tile::FlatmmPipelineAGmemBGmemCRegV1; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. using Kernel = ck_tile::FlatmmKernel; auto kargs = Kernel::MakeKernelArgs(args); const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); constexpr dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } if(s.log_level_ > 0) { std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } float ave_time = ck_tile::launch_kernel( s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; } #include "run_flatmm_example.inc" int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); }