Files
composable_kernel/host/host_tensor/include/host_gemm.hpp
Chao Liu 823657ed12 GEMM+Bias+ReLU+Add (#76)
* tweak conv for odd C

* update script

* clean up elementwise op

* fix build

* clean up

* added example for gemm+bias+relu+add

* added example for gemm+bias+relu

* add profiler for gemm_s_shuffle; re-org files

* add profiler

* fix build

* clean up

* clean up

* clean up

* fix build
2022-02-06 22:32:47 -06:00

44 lines
1.3 KiB
C++

#pragma once
#include <cstddef>
#include <thread>
#include "host_tensor.hpp"
/// @brief Host (CPU) reference GEMM:
///        C(m, n) = c_op( sum_k a_op(A(m, k)) * b_op(B(k, n)) ).
///
/// A is indexed as (m, k), B as (k, n) and C as (m, n). Every A and B element is
/// converted to float before its elementwise op is applied, and the dot product is
/// accumulated in float regardless of AType / BType / CType.
///
/// @param a_m_k        Input A; the reduction length K is taken from its second length.
/// @param b_k_n        Input B; its first length is assumed to equal K (NOTE: not checked here).
/// @param c_m_n        Output C; its two lengths define the M x N iteration space.
/// @param a_element_op Invoked as a_element_op(float& out, float in) on each A element.
/// @param b_element_op Invoked as b_element_op(float& out, float in) on each B element.
/// @param c_element_op Invoked as c_element_op(float& out, float acc) on each accumulated sum;
///                     its output is assigned into c_m_n(m, n).
///
/// The (m, n) iteration space is dispatched through make_ParallelTensorFunctor using
/// up to std::thread::hardware_concurrency() threads (at least one).
template <typename AType,
          typename BType,
          typename CType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
                        const Tensor<BType>& b_k_n,
                        Tensor<CType>& c_m_n,
                        const AElementwiseOperation& a_element_op,
                        const BElementwiseOperation& b_element_op,
                        const CElementwiseOperation& c_element_op)
{
    // The reduction length is identical for every output element: read it once
    // instead of re-querying the descriptor inside each (m, n) invocation.
    const std::size_t K = a_m_k.mDesc.GetLengths()[1];

    auto f_mk_kn_mn = [&](auto m, auto n) {
        float v_acc = 0;

        for(std::size_t k = 0; k < K; ++k)
        {
            // Initialized so the value is defined even if an op does not write its output.
            float v_a = 0;
            float v_b = 0;

            a_element_op(v_a, static_cast<float>(a_m_k(m, k)));
            b_element_op(v_b, static_cast<float>(b_k_n(k, n)));

            v_acc += v_a * v_b;
        }

        float v_c = 0;
        c_element_op(v_c, v_acc);

        c_m_n(m, n) = v_c;
    };

    // std::thread::hardware_concurrency() is allowed to return 0 when the value is
    // not computable; never request zero worker threads.
    const unsigned hw_threads    = std::thread::hardware_concurrency();
    const std::size_t num_thread = hw_threads == 0 ? 1 : hw_threads;

    make_ParallelTensorFunctor(f_mk_kn_mn,
                               c_m_n.mDesc.GetLengths()[0],
                               c_m_n.mDesc.GetLengths()[1])(num_thread);
}