// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include #include #include #include #include #include #include "ck_tile/core.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/host.hpp" #include "ck_tile/ops/batched_contraction.hpp" #include "contraction_utils.hpp" template float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs& args, const ck_tile::stream_config& s) { constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 16; constexpr bool DoubleSmemBuffer = false; constexpr bool kPadM = false; constexpr bool kPadN = false; constexpr bool kPadK = false; constexpr bool TransposeC = false; constexpr int kBlockPerCu = 1; constexpr ck_tile::index_t TileParitionerGroupNum = 8; constexpr ck_tile::index_t TileParitionerM01 = 4; using GemmShape = ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; using Problem = ck_tile::BatchedContractionProblem; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; const auto Run = [&]() { constexpr auto memory_operation = ck_tile::memory_operation_enum::set; // Always set (no atomic_add) using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; using GemmPipeline = GEMM_PIPELINE; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::BatchedContractionKernel; auto kargs = Kernel::MakeKernelArgs(args); const dim3 grids = Kernel::GridSize(kargs); const dim3 blocks = Kernel::GetBlockSize(); if(!Kernel::IsSupportedArguments(kargs)) { throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n"); } if(s.log_level_ > 0) { std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n' << "shape: " << GemmShape::GetName() << '\n' << "problem: " << GemmPipelineProblem::GetName() << '\n' << "pipeline: " << GemmPipeline::GetName() << '\n' << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } auto kernel = ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs); return ck_tile::launch_kernel(s, kernel); }; return Run(); } #define HANDLE_CASE(G, M, N, K) \ if(num_g_dims == G && num_m_dims == M && num_n_dims == N && num_k_dims == K) \ { \ return batched_contraction_impl(args, s); \ } template float batched_contraction(const ck_tile::BatchedContractionHostArgs& args, const ck_tile::stream_config& s, ck_tile::index_t num_g_dims, ck_tile::index_t num_m_dims, ck_tile::index_t num_n_dims, ck_tile::index_t num_k_dims) { std::cout << "Dimensions: G=" << num_g_dims << ", M=" << num_m_dims << ", N=" << num_n_dims << ", K=" << num_k_dims << std::endl; HANDLE_CASE(1, 1, 1, 1); HANDLE_CASE(2, 1, 1, 1); HANDLE_CASE(2, 2, 2, 1); HANDLE_CASE(1, 2, 1, 1); HANDLE_CASE(2, 2, 2, 2); throw std::runtime_error( "Unsupported dimension combination: G=" + std::to_string(num_g_dims) + ", M=" + std::to_string(num_m_dims) + ", N=" + std::to_string(num_n_dims) + ", K=" + std::to_string(num_k_dims) + ". Please add this combination to the kernel."); } #include "run_batched_contraction_example.inc" int main(int argc, char* argv[]) { try { return !run_batched_contraction_example(argc, argv); } catch(const std::runtime_error& e) { std::cerr << "Runtime error: " << e.what() << '\n'; return EXIT_FAILURE; } }