// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" template struct GemmBasicTypeConfig; template <> struct GemmBasicTypeConfig { using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; using CDataType = ck_tile::half_t; using AccDataType = float; }; using Types = GemmBasicTypeConfig; // Specific type aliases for easy access using ADataType = Types::ADataType; using BDataType = Types::BDataType; using AccDataType = Types::AccDataType; using CDataType = Types::CDataType; using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("Ms", "", "M dimensions - empty by default.") .insert("Ns", "", "N dimensions - empty by default.") .insert("Ks", "", "K dimensions - empty by default.") .insert("stride_As", "", "Tensor A strides - it is empty by default.") .insert("stride_Bs", "", "Tensor B strides - it is empty by default.") .insert("stride_Cs", "", "Tensor C strides - it is empty by default.") .insert("a_layout", "R", "A tensor data layout - Row by default.") .insert("b_layout", "C", "B tensor data layout - Row by default.") .insert("c_layout", "R", "C tensor data layout - Row by default.") .insert("validate", "1", "0. No validation, 1. Validation on CPU.") .insert("warmup", "10", "number of iterations before benchmark the kernel.") .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "16", "group count."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } std::size_t get_workspace_size(const std::vector& gemm_descs); float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* p_workspace_);