mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
* Checkpoint: Finished with the tile example & kernel verification, working on the different matrix layout * Finished the Matrix Layout feature set up. Note: Need to modify the inner block to solve the shuffle problem in the future. * Fix: Clang Format, API fixed from fmha * fix with better naming convention * revert back the pipeline code of fmha * Fixed: Addressed the comments and merge the GEMM shape of GEMM Operator and FMHA Operator to one. * clang format with the reference_gemm file * convert the clang format with the remod.py * Changed the format and variable name of the kernel gemm_shape and partitioner --------- Co-authored-by: thomasning <thomasning@banff-cyxtera-s70-4.ctr.dcgpu>
72 lines
1.6 KiB
C++
72 lines
1.6 KiB
C++
|
|
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/host/kernel_launch.hpp"
|
|
#include "ck_tile/ops/epilogue.hpp"
|
|
#include "ck_tile/ops/gemm.hpp"
|
|
#include <string>
|
|
|
|
/// Primary template for the GEMM type configuration.
///
/// Only declared here: a configuration exists solely for element types that
/// provide an explicit specialization, so using an unsupported `DataType`
/// fails at compile time with an incomplete-type error.
///
/// @tparam DataType Element type that selects the configuration.
template <typename DataType>
struct GemmBasicTypeConfig;
|
|
|
|
template <>
|
|
struct GemmBasicTypeConfig<ck_tile::half_t>
|
|
{
|
|
using ADataType = ck_tile::half_t;
|
|
using BDataType = ck_tile::half_t;
|
|
using AccDataType = float;
|
|
using CDataType = ck_tile::half_t; // type convert
|
|
// ToDo: Add more bias config to support different categories of GEMM.
|
|
};
|
|
|
|
/// Maps an element type to a short, human-readable name.
///
/// The primary template is only declared; each supported type supplies an
/// explicit specialization exposing a `name` string constant. Querying an
/// unsupported type is a compile-time error (incomplete type).
///
/// @tparam T Element type to describe.
template <typename T>
struct DataTypeTraits;

/// Name for `float`.
template <>
struct DataTypeTraits<float>
{
    static constexpr const char* name = "fp32";
};

/// Name for `double`.
template <>
struct DataTypeTraits<double>
{
    static constexpr const char* name = "fp64";
};
|
|
|
|
template <>
|
|
struct DataTypeTraits<ck_tile::half_t>
|
|
{
|
|
static constexpr const char* name = "fp16";
|
|
};
|
|
|
|
using Types = GemmBasicTypeConfig<ck_tile::half_t>;
|
|
|
|
// Specific type aliases for easy access
|
|
using ADataType = Types::ADataType;
|
|
using BDataType = Types::BDataType;
|
|
using AccDataType = Types::AccDataType;
|
|
using CDataType = Types::CDataType;
|
|
|
|
struct gemm_basic_args
|
|
{
|
|
const void* p_a;
|
|
const void* p_b;
|
|
void* p_c;
|
|
float epsilon;
|
|
ck_tile::index_t kbatch;
|
|
ck_tile::index_t M;
|
|
ck_tile::index_t N;
|
|
ck_tile::index_t K;
|
|
ck_tile::index_t stride_A;
|
|
ck_tile::index_t stride_B;
|
|
ck_tile::index_t stride_C;
|
|
};
|
|
|
|
// host API
|
|
// Host-side GEMM entry point. Only the declaration is visible in this header.
//   args: argument bundle (buffers, sizes, strides), taken by value.
//   s:    ck_tile stream configuration, taken by const reference.
// Returns a float. NOTE(review): presumably the measured kernel time -
// confirm against the definition.
float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s);
|