mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
* tweak conv for odd C * update script * clean up elementwise op * fix build * clean up * added example for gemm+bias+relu+add * added example for gemm+bias+relu * add profiler for gemm_s_shuffle; re-org files * add profiler * fix build * clean up * clean up * clean up * fix build
145 lines
4.3 KiB
C++
145 lines
4.3 KiB
C++
#ifndef REFERENCE_GEMM_BIAS_ACTIVATION_ADD_HPP
|
|
#define REFERENCE_GEMM_BIAS_ACTIVATION_ADD_HPP
|
|
|
|
#include <cstddef>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>

#include "device_base.hpp"
#include "host_tensor.hpp"
|
|
|
|
namespace ck {
|
|
namespace tensor_operation {
|
|
namespace host {
|
|
|
|
template <typename ADataType,
|
|
typename BDataType,
|
|
typename CDataType,
|
|
typename AElementwiseOperation,
|
|
typename BElementwiseOperation,
|
|
typename CElementwiseOperation>
|
|
struct ReferenceGemmBiasActivationAdd : public device::BaseOperator
|
|
{
|
|
// Argument
|
|
struct Argument : public device::BaseArgument
|
|
{
|
|
Argument(const Tensor<ADataType>& a_m_k,
|
|
const Tensor<BDataType>& b_k_n,
|
|
Tensor<CDataType>& c_m_n,
|
|
const Tensor<CDataType>& c0_n,
|
|
const Tensor<CDataType>& c1_m_n,
|
|
AElementwiseOperation a_element_op,
|
|
BElementwiseOperation b_element_op,
|
|
CElementwiseOperation c_element_op)
|
|
: a_m_k_{a_m_k},
|
|
b_k_n_{b_k_n},
|
|
c_m_n_{c_m_n},
|
|
c0_n_{c0_n},
|
|
c1_m_n_{c1_m_n},
|
|
a_element_op_{a_element_op},
|
|
b_element_op_{b_element_op},
|
|
c_element_op_{c_element_op}
|
|
{
|
|
}
|
|
|
|
const Tensor<ADataType>& a_m_k_;
|
|
const Tensor<BDataType>& b_k_n_;
|
|
Tensor<CDataType>& c_m_n_;
|
|
const Tensor<CDataType>& c0_n_;
|
|
const Tensor<CDataType>& c1_m_n_;
|
|
|
|
AElementwiseOperation a_element_op_;
|
|
BElementwiseOperation b_element_op_;
|
|
CElementwiseOperation c_element_op_;
|
|
};
|
|
|
|
// Invoker
|
|
struct Invoker : public device::BaseInvoker
|
|
{
|
|
using Argument = ReferenceGemmBiasActivationAdd::Argument;
|
|
|
|
float Run(const Argument& arg)
|
|
{
|
|
auto f_mk_kn_mn = [&](auto m, auto n) {
|
|
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
|
|
|
|
float v_acc = 0;
|
|
|
|
for(int k = 0; k < K; ++k)
|
|
{
|
|
float v_a;
|
|
float v_b;
|
|
|
|
arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k)));
|
|
arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n)));
|
|
|
|
v_acc += v_a * v_b;
|
|
}
|
|
|
|
float v_c;
|
|
|
|
arg.c_element_op_(v_c,
|
|
v_acc,
|
|
static_cast<float>(arg.c0_n_(n)),
|
|
static_cast<float>(arg.c1_m_n_(m, n)));
|
|
|
|
arg.c_m_n_(m, n) = v_c;
|
|
};
|
|
|
|
make_ParallelTensorFunctor(
|
|
f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
|
|
std::thread::hardware_concurrency());
|
|
|
|
return 0;
|
|
}
|
|
|
|
float Run(const device::BaseArgument* p_arg, int) override
|
|
{
|
|
return Run(*dynamic_cast<const Argument*>(p_arg));
|
|
}
|
|
};
|
|
|
|
static constexpr bool IsValidCompilationParameter()
|
|
{
|
|
// TODO: properly implement this check
|
|
return true;
|
|
}
|
|
|
|
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
|
|
|
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
|
|
const Tensor<BDataType>& b_k_n,
|
|
Tensor<CDataType>& c_m_n,
|
|
const Tensor<CDataType>& c0_n,
|
|
const Tensor<CDataType>& c1_m_n,
|
|
AElementwiseOperation a_element_op,
|
|
BElementwiseOperation b_element_op,
|
|
CElementwiseOperation c_element_op)
|
|
{
|
|
return Argument{
|
|
a_m_k, b_k_n, c_m_n, c0_n, c1_m_n, a_element_op, b_element_op, c_element_op};
|
|
}
|
|
|
|
static auto MakeInvoker() { return Invoker{}; }
|
|
|
|
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
|
|
{
|
|
return std::make_unique<Invoker>(Invoker{});
|
|
}
|
|
|
|
std::string GetTypeString() const override
|
|
{
|
|
auto str = std::stringstream();
|
|
|
|
// clang-format off
|
|
str << "ReferenceGemmBiasActivationAdd"
|
|
<< std::endl;
|
|
// clang-format on
|
|
|
|
return str.str();
|
|
}
|
|
};
|
|
|
|
} // namespace host
|
|
} // namespace tensor_operation
|
|
} // namespace ck
|
|
#endif
|