// SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #define CK_TILE_PIPELINE_COMPUTE_V3 1 #define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_COMPUTE_V4 3 #ifndef CK_TILE_PIPELINE_DEFAULT #define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 #endif #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave #else #error "unsupported CK_TILE_PIPELINE_DEFAULT value" #endif template struct GemmTypeConfig; template <> struct GemmTypeConfig { using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; using CDataType = ck_tile::half_t; using AccDataType = float; }; using Types = GemmTypeConfig; // Specific type aliases for easy access using ADataType = Types::ADataType; using BDataType = Types::BDataType; using AccDataType = Types::AccDataType; using CDataType = Types::CDataType; using grouped_gemm_kargs = ck_tile::GemmHostArgs; auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("Ms", "", "M dimensions - empty by default.") .insert("Ns", "", "N dimensions - empty by default.") .insert("Ks", "", "K dimensions - empty by default.") .insert("stride_As", "", "Tensor A strides - it is empty by default.") .insert("stride_Bs", "", "Tensor B strides - it is empty by default.") .insert("stride_Cs", "", "Tensor C strides - it is empty by default.") .insert("a_layout", "R", "A tensor data layout - Row by default.") .insert("b_layout", "C", "B tensor data layout - Row by default.") .insert("c_layout", "R", "C tensor data layout - Row by default.") .insert("validate", "1", "0. No validation, 1. Validation on CPU.") .insert("warmup", "10", "number of iterations before benchmark the kernel.") .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") .insert("kbatch", "1", "kbatch for SplitK"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } inline std::size_t get_workspace_size(const std::vector& gemm_descs) { return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } template float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* kargs_ptr); template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr, bool splitk = false);