// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" template struct GemmBasicTypeConfig; template <> struct GemmBasicTypeConfig { using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; using AccDataType = float; using CDataType = ck_tile::half_t; // ToDo: Add more bias config to support different categories of GEMM. }; template struct DataTypeTraits; template <> struct DataTypeTraits { static constexpr const char* name = "fp32"; }; template <> struct DataTypeTraits { static constexpr const char* name = "fp64"; }; template <> struct DataTypeTraits { static constexpr const char* name = "fp16"; }; using Types = GemmBasicTypeConfig; // Specific type aliases for easy access using ADataType = Types::ADataType; using BDataType = Types::BDataType; using AccDataType = Types::AccDataType; using CDataType = Types::CDataType; struct gemm_basic_args { const void* p_a; const void* p_b; void* p_c; ck_tile::index_t kbatch; ck_tile::index_t M; ck_tile::index_t N; ck_tile::index_t K; ck_tile::index_t stride_A; ck_tile::index_t stride_B; ck_tile::index_t stride_C; }; auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("b", "1", "batch size") .insert("m", "3840", "m dimension") .insert("n", "4096", "n dimension") .insert("k", "2048", "k dimension") .insert("a_layout", "R", "A tensor data layout - Row by default") .insert("b_layout", "R", "B tensor data layout - Row by default") .insert("c_layout", "R", "C tensor data layout - Row by default") .insert("stride_a", "0", "Tensor A stride") .insert("stride_b", "0", "Tensor B stride") .insert("stride_c", "0", "Tensor C stride") .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } // host API float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s);