mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
* delete obsolete files
* move files
* build
* update cmake
* update cmake
* fix build
* reorg examples
* update cmake for example and test
[ROCm/composable_kernel commit: 5d37d7bff4]
104 lines
3.5 KiB
C++
104 lines
3.5 KiB
C++
#ifndef GEMM_UTILS_HPP
|
|
#define GEMM_UTILS_HPP
|
|
|
|
#include "config.hpp"
|
|
#include "device.hpp"
|
|
#include "host_tensor.hpp"
|
|
|
|
namespace ck {
|
|
namespace gemm_util {
|
|
|
|
struct GemmParams
|
|
{
|
|
GemmParams()
|
|
: M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0)
|
|
{
|
|
}
|
|
|
|
ck::index_t M;
|
|
ck::index_t N;
|
|
ck::index_t K;
|
|
|
|
ck::index_t StrideA;
|
|
ck::index_t StrideB;
|
|
ck::index_t StrideC;
|
|
|
|
float alpha;
|
|
float beta;
|
|
};
|
|
|
|
template <typename GemmInstance,
|
|
typename ADataType,
|
|
typename BDataType,
|
|
typename CDataType,
|
|
typename AElementwiseOperation,
|
|
typename BElementwiseOperation,
|
|
typename CElementwiseOperation>
|
|
void RunHostGEMM(const Tensor<ADataType>& A,
|
|
const Tensor<BDataType>& B,
|
|
Tensor<CDataType>& C,
|
|
AElementwiseOperation a_element_op,
|
|
BElementwiseOperation b_element_op,
|
|
CElementwiseOperation c_element_op)
|
|
{
|
|
auto ref_gemm = GemmInstance{};
|
|
auto ref_invoker = ref_gemm.MakeInvoker();
|
|
|
|
auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);
|
|
|
|
ref_invoker.Run(ref_argument);
|
|
}
|
|
|
|
template <typename DeviceGemmPtr_,
|
|
typename ADataType,
|
|
typename BDataType,
|
|
typename CDataType,
|
|
typename AElementwiseOperation,
|
|
typename BElementwiseOperation,
|
|
typename CElementwiseOperation>
|
|
void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
|
|
const ck::gemm_util::GemmParams& params,
|
|
const Tensor<ADataType>& A,
|
|
const Tensor<BDataType>& B,
|
|
Tensor<CDataType>& C,
|
|
AElementwiseOperation a_element_op,
|
|
BElementwiseOperation b_element_op,
|
|
CElementwiseOperation c_element_op)
|
|
{
|
|
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace());
|
|
DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace());
|
|
DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace());
|
|
|
|
a_m_k_device_buf.ToDevice(A.mData.data());
|
|
b_k_n_device_buf.ToDevice(B.mData.data());
|
|
|
|
auto invoker_ptr = gemmPtr->MakeInvokerPointer();
|
|
auto argument_ptr =
|
|
gemmPtr->MakeArgumentPointer(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
|
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
|
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
|
params.M,
|
|
params.N,
|
|
params.K,
|
|
params.StrideA,
|
|
params.StrideB,
|
|
params.StrideC,
|
|
a_element_op,
|
|
b_element_op,
|
|
c_element_op);
|
|
|
|
if(!gemmPtr->IsSupportedArgument(argument_ptr.get()))
|
|
{
|
|
throw std::runtime_error(
|
|
"wrong! device_gemm with the specified compilation parameters does "
|
|
"not support this GEMM problem");
|
|
}
|
|
|
|
invoker_ptr->Run(argument_ptr.get());
|
|
c_m_n_device_buf.FromDevice(C.mData.data());
|
|
}
|
|
|
|
} // namespace gemm_util
|
|
} // namespace ck
|
|
#endif
|