// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

// NOTE(review): this file appears to have been damaged in extraction. The six
// bare `#include` directives below have lost their header names, and template
// parameter / argument lists throughout the file look stripped (e.g.
// `template float batched_gemm`, `ck_tile::TileGemmShape, ck_tile::sequence,
// ck_tile::sequence>`, `ck_tile::integral_constant{}`). As written it will not
// compile; restore the missing text from the upstream source — TODO confirm.
#include
#include
#include
#include
#include
#include
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "batched_gemm.hpp"

// Configure, validate and launch a ck_tile batched GEMM kernel.
//
// All tile/warp geometry and the pipeline scheduler come from the compile-time
// `GemmConfig` (presumably a template parameter whose declaration was stripped
// from the `template` line — TODO confirm).
//
// @param args  Host-side problem description; this function reads args.M,
//              args.N, args.k_batch and args.batch_count directly and passes
//              the whole struct to Kernel::MakeKernelArgs.
// @param s     Stream configuration; s.log_level_ > 0 enables a launch log on
//              std::cout, and `s` is forwarded to ck_tile::launch_kernel.
// @return      Whatever ck_tile::launch_kernel returns, as a float (presumably
//              measured kernel time — confirm against ck_tile::launch_kernel).
// @throws      std::runtime_error if Kernel::IsSupportedArgument rejects the
//              arguments.
template float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
{
    // Block-level tile sizes, warp layout and per-warp tile sizes, all taken
    // verbatim from GemmConfig.
    constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
    constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
    constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
    constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
    constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
    constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
    constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
    constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
    constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
    constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;

    // Padding is disabled in every dimension. NOTE(review): presumably this
    // means M/N/K must be tile-compatible or IsSupportedArgument will reject
    // the problem — confirm.
    constexpr bool kPadM = false;
    constexpr bool kPadN = false;
    constexpr bool kPadK = false;
    constexpr bool TransposeC = false;
    constexpr int kBlockPerCu = 1;

    // Parameters for the spatially-local tile partitioner below.
    // ("Paritioner" is a spelling slip in these local names; harmless.)
    constexpr ck_tile::index_t TileParitionerGroupNum = 8;
    constexpr ck_tile::index_t TileParitionerM01 = 4;

    // NOTE(review): the template argument lists of the following type aliases
    // were stripped; they presumably consumed the constants declared above.
    using GemmShape = ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>;
    using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner;
    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits;
    constexpr auto scheduler = GemmConfig::Scheduler;
    using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem;
    using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>;

    // Builds the epilogue/kernel types for one compile-time memory operation
    // (received as an integral-constant-like tag whose `.value` is read),
    // validates the arguments, optionally logs, and launches.
    const auto Run = [&](const auto memory_operation_) {
        constexpr auto memory_operation = memory_operation_.value;
        using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>;
        using Kernel = ck_tile::BatchedGemmKernel;
        auto kargs = Kernel::MakeKernelArgs(args);
        const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
        const dim3 blocks = Kernel::BlockSize();

        // Refuse to launch a configuration the kernel reports as unsupported.
        if(!Kernel::IsSupportedArgument(kargs)) {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
        }

        if(s.log_level_ > 0) {
            std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
                      << "shape: " << GemmShape::GetName() << '\n'
                      << "pipeline: " << GemmPipeline::GetName() << '\n'
                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }

        return ck_tile::launch_kernel(s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs));
    };

    // Dispatch on k_batch. NOTE(review): the integral_constant template
    // arguments were stripped, so both branches now read identically; they
    // presumably selected different epilogue memory operations (e.g. plain
    // store when k_batch == 1 vs. an accumulating write for split-K) — confirm.
    if(args.k_batch == 1) {
        return Run(ck_tile::integral_constant{});
    } else {
        return Run(ck_tile::integral_constant{});
    }
}

// Expected to define run_batched_gemm_example, which main() calls below.
#include "run_batched_gemm_example.inc"

// Entry point: runs the example and maps its result to a process exit code via
// logical negation (so a truthy result yields exit status 0). Only
// std::runtime_error is caught and reported; any other exception propagates.
// NOTE(review): the WMMA / non-WMMA branches are identical here because their
// template arguments were stripped; they presumably selected different
// GemmConfig types — confirm against upstream.
int main(int argc, char* argv[])
{
    try {
#if CK_TILE_USE_WMMA
        return !run_batched_gemm_example(argc, argv);
#else
        return !run_batched_gemm_example(argc, argv);
#endif
    } catch(const std::runtime_error& e) {
        std::cerr << "Runtime error: " << e.what() << '\n';
        return EXIT_FAILURE;
    }
}