// SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include "ck_tile/host.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/epilogue.hpp" struct GemmFp16 { }; struct GemmBf16 { }; template struct GemmBasicTypeConfig; template <> struct GemmBasicTypeConfig { using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; using AccDataType = float; using CDataType = ck_tile::half_t; // ToDo: Add more bias config to support different categories of GEMM. }; template <> struct GemmBasicTypeConfig { using ADataType = ck_tile::bf16_t; using BDataType = ck_tile::bf16_t; using AccDataType = float; using CDataType = ck_tile::bf16_t; }; template struct DataTypeTraits; template <> struct DataTypeTraits { static constexpr const char* name = "fp32"; }; template <> struct DataTypeTraits { static constexpr const char* name = "fp64"; }; template <> struct DataTypeTraits { static constexpr const char* name = "fp16"; }; template <> struct DataTypeTraits { static constexpr const char* name = "bf16"; }; using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; struct gemm_traits { std::string data_type; bool is_a_rowmajor; bool is_b_rowmajor; bool is_c_rowmajor; }; template struct gemm_traits_ { using ADataType = ck_tile::remove_cvref_t; using BDataType = ck_tile::remove_cvref_t; using AccDataType = ck_tile::remove_cvref_t; using CDataType = ck_tile::remove_cvref_t; using ALayout = ck_tile::remove_cvref_t; using BLayout = ck_tile::remove_cvref_t; using CLayout = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t M_Tile = M_Tile_; static constexpr ck_tile::index_t N_Tile = N_Tile_; static constexpr ck_tile::index_t K_Tile = K_Tile_; static constexpr ck_tile::index_t M_Warp = M_Warp_; static constexpr ck_tile::index_t N_Warp = N_Warp_; static constexpr ck_tile::index_t K_Warp = K_Warp_; static constexpr ck_tile::index_t M_Warp_Tile = M_Warp_Tile_; static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_; static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_; static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_; static constexpr bool kPadK = kPadK_; }; // host API template float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); float gemm(const gemm_traits& traits, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);