mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
GEMM+Bias+ReLU+Add (#76)
* tweak conv for odd C
* update script
* clean up elementwise op
* fix build
* clean up
* added example for gemm+bias+relu+add
* added example for gemm+bias+relu
* add profiler for gemm_s_shuffle; re-org files
* add profiler
* fix build
* clean up
* clean up
* clean up
* fix build
[ROCm/composable_kernel commit: 823657ed12]
This commit is contained in:
@@ -17,15 +17,24 @@ void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
     auto f_mk_kn_mn = [&](auto m, auto n) {
         const int K = a_m_k.mDesc.GetLengths()[1];
 
-        double v = 0;
+        float v_acc = 0;
 
         for(int k = 0; k < K; ++k)
         {
-            v += static_cast<const double>(a_element_op(a_m_k(m, k))) *
-                 static_cast<const double>(b_element_op(b_k_n(k, n)));
+            float v_a;
+            float v_b;
+
+            a_element_op(v_a, static_cast<const float>(a_m_k(m, k)));
+            b_element_op(v_b, static_cast<const float>(b_k_n(k, n)));
+
+            v_acc += v_a * v_b;
         }
 
-        c_m_n(m, n) = c_element_op(v);
+        float v_c;
+
+        c_element_op(v_c, v_acc);
+
+        c_m_n(m, n) = v_c;
     };
 
     make_ParallelTensorFunctor(f_mk_kn_mn,
||||
Reference in New Issue
Block a user