mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Gemm alpha beta profiler (fp32 & fp16) (#91)
* [What] Refactor verification of gemm alpha_beta, move to reference operation
[Why] Sync with other verification
* Profile mk_nk for gemm bias 2d
* Support bias 2d with mn * kn in profiler
* Support bias 2d with km*kn and km*nk in profiler
* Support fp32 bias 2d in profiler
* format
* format
Co-authored-by: rocking <chunylai@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 19c5d6e651]
This commit is contained in:
133
reference_operation/include/reference_gemm_bias_2d.hpp
Normal file
133
reference_operation/include/reference_gemm_bias_2d.hpp
Normal file
@@ -0,0 +1,133 @@
|
||||
#ifndef REFERENCE_GEMM_BIAS_BIAS_2D_HPP
|
||||
#define REFERENCE_GEMM_BIAS_BIAS_2D_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device_base.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename C0DataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct ReferenceGemmBias2D : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_m_k,
|
||||
const Tensor<BDataType>& b_k_n,
|
||||
const Tensor<C0DataType>& c0_m_n,
|
||||
Tensor<CDataType>& c_m_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: a_m_k_{a_m_k},
|
||||
b_k_n_{b_k_n},
|
||||
c0_m_n_{c0_m_n},
|
||||
c_m_n_{c_m_n},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_m_k_;
|
||||
const Tensor<BDataType>& b_k_n_;
|
||||
const Tensor<CDataType>& c0_m_n_;
|
||||
Tensor<CDataType>& c_m_n_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceGemmBias2D::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
|
||||
|
||||
AccDataType a = 0;
|
||||
AccDataType b = 0;
|
||||
AccDataType acc = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
arg.a_element_op_(a, arg.a_m_k_(m, k));
|
||||
arg.b_element_op_(b, arg.b_k_n_(k, n));
|
||||
acc += a * b;
|
||||
}
|
||||
|
||||
CDataType cast_acc = static_cast<CDataType>(acc);
|
||||
arg.c_element_op_(arg.c_m_n_(m, n), cast_acc, arg.c0_m_n_(m, n));
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(
|
||||
f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const device::BaseArgument* p_arg, int) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
|
||||
const Tensor<BDataType>& b_k_n,
|
||||
const Tensor<C0DataType>& c0_m_n,
|
||||
Tensor<CDataType>& c_m_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{a_m_k, b_k_n, c0_m_n, c_m_n, a_element_op, b_element_op, c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceGemmBias2D"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
Reference in New Issue
Block a user