Files
composable_kernel/example/ck_tile/03_gemm/gemm_basic.hpp
2025-01-12 15:11:04 +00:00

130 lines
3.7 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"
/// Empty tag type that selects the fp16 flavour of GemmBasicTypeConfig.
struct GemmFp16 {};
/// Empty tag type that selects the bf16 flavour of GemmBasicTypeConfig.
struct GemmBf16 {};
/// Maps a precision tag type to the element types used by the basic GEMM
/// example. The primary template has no definition, so only explicitly
/// specialized tags can be used.
template <typename DataType>
struct GemmBasicTypeConfig;

/// fp16 A/B/C tensors with fp32 accumulation.
template <>
struct GemmBasicTypeConfig<GemmFp16>
{
    using AccDataType = float;           // accumulator type
    using ADataType   = ck_tile::half_t; // A operand element type
    using BDataType   = ck_tile::half_t; // B operand element type
    using CDataType   = ck_tile::half_t; // output element type
    // TODO: add further bias configurations to cover other categories of GEMM.
};
/// bf16 A/B/C tensors with fp32 accumulation.
template <>
struct GemmBasicTypeConfig<GemmBf16>
{
    using AccDataType = float;           // accumulator type
    using ADataType   = ck_tile::bf16_t; // A operand element type
    using BDataType   = ck_tile::bf16_t; // B operand element type
    using CDataType   = ck_tile::bf16_t; // output element type
};
/// Maps an element type to a short printable name. The primary template is
/// declared only; an unsupported type fails to compile when `name` is used.
template <typename T>
struct DataTypeTraits;

/// 32-bit IEEE float.
template <>
struct DataTypeTraits<float>
{
    static constexpr const char* name{"fp32"};
};

/// 64-bit IEEE double.
template <>
struct DataTypeTraits<double>
{
    static constexpr const char* name{"fp64"};
};
/// ck_tile 16-bit half-precision float.
template <>
struct DataTypeTraits<ck_tile::half_t>
{
    static constexpr const char* name{"fp16"};
};

/// ck_tile bfloat16.
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
    static constexpr const char* name{"bf16"};
};
// Short aliases for the ck_tile GEMM tensor-layout tag types.
using Row = ck_tile::tensor_layout::gemm::RowMajor;    // row-major layout tag
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; // column-major layout tag
/// Runtime description of a GEMM configuration, passed to the host entry
/// point `gemm()`.
///
/// The layout flags default to `false`, so a default-initialized instance
/// (`gemm_traits t;`) no longer holds indeterminate values. The type remains
/// an aggregate, so brace initialization like
/// `gemm_traits{"fp16", true, false, true}` still works.
struct gemm_traits
{
    std::string data_type;      // data-type tag; presumably matches DataTypeTraits names ("fp16", "bf16") — confirm with dispatcher
    bool is_a_rowmajor{false};  // true when A is row-major
    bool is_b_rowmajor{false};  // true when B is row-major
    bool is_c_rowmajor{false};  // true when C is row-major
};
/// Compile-time bundle of the parameters that identify one GEMM kernel
/// instance. Each template argument is re-exported as a member of the same
/// name without the trailing underscore. Type arguments are passed through
/// ck_tile::remove_cvref_t first.
template <typename ADataType_,
          typename BDataType_,
          typename AccDataType_,
          typename CDataType_,
          typename ALayout_,
          typename BLayout_,
          typename CLayout_,
          ck_tile::index_t M_Tile_,
          ck_tile::index_t N_Tile_,
          ck_tile::index_t K_Tile_,
          ck_tile::index_t M_Warp_,
          ck_tile::index_t N_Warp_,
          ck_tile::index_t K_Warp_,
          ck_tile::index_t M_Warp_Tile_,
          ck_tile::index_t N_Warp_Tile_,
          ck_tile::index_t K_Warp_Tile_,
          bool kPadM_,
          bool kPadN_,
          bool kPadK_>
struct gemm_traits_
{
    // --- element types --------------------------------------------------
    using ADataType   = ck_tile::remove_cvref_t<ADataType_>;
    using BDataType   = ck_tile::remove_cvref_t<BDataType_>;
    using AccDataType = ck_tile::remove_cvref_t<AccDataType_>;
    using CDataType   = ck_tile::remove_cvref_t<CDataType_>;

    // --- tensor layouts -------------------------------------------------
    using ALayout = ck_tile::remove_cvref_t<ALayout_>;
    using BLayout = ck_tile::remove_cvref_t<BLayout_>;
    using CLayout = ck_tile::remove_cvref_t<CLayout_>;

    // --- tile extents along M / N / K -----------------------------------
    static constexpr ck_tile::index_t M_Tile      = M_Tile_;
    static constexpr ck_tile::index_t N_Tile      = N_Tile_;
    static constexpr ck_tile::index_t K_Tile      = K_Tile_;

    // --- warp counts along M / N / K ------------------------------------
    static constexpr ck_tile::index_t M_Warp      = M_Warp_;
    static constexpr ck_tile::index_t N_Warp      = N_Warp_;
    static constexpr ck_tile::index_t K_Warp      = K_Warp_;

    // --- per-warp tile extents along M / N / K --------------------------
    static constexpr ck_tile::index_t M_Warp_Tile = M_Warp_Tile_;
    static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_;
    static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_;

    // --- padding switches per dimension ---------------------------------
    static constexpr bool kPadM = kPadM_;
    static constexpr bool kPadN = kPadN_;
    static constexpr bool kPadK = kPadK_;
};
// ---------------------------------------------------------------------------
// Host API (definitions live outside this header)
// ---------------------------------------------------------------------------

/// Runs the GEMM described by the compile-time traits bundle `Traits`.
/// NOTE(review): `Traits` is presumably a `gemm_traits_<...>` instantiation,
/// and the returned float is presumably the launcher's reported kernel time.
/// Confirm both against the definition.
template <typename Traits>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);

/// Runtime entry point: takes a runtime `gemm_traits` description, the problem
/// arguments `args` and the stream configuration `s`.
/// NOTE(review): presumably dispatches to a matching `gemm_<...>`
/// instantiation. Confirm against the definition.
float gemm(const gemm_traits& traits,
           const ck_tile::GemmHostArgs& args,
           const ck_tile::stream_config& s);