mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
* chore(copyright): update copyright header for tile_engine directory * chore(copyright): update copyright header for script directory * chore(copyright): update copyright header for test_data directory * chore(copyright): update copyright header for python directory * chore(copyright): update copyright header for profiler directory * chore(copyright): update copyright header for library directory * chore(copyright): update copyright header for include directory
48 lines
1.4 KiB
C++
48 lines
1.4 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once

#include <cassert>
#include <cstddef>
#include <thread>

#include "host_tensor.hpp"
|
|
|
|
template <typename AType,
|
|
typename BType,
|
|
typename CType,
|
|
typename AElementwiseOperation,
|
|
typename BElementwiseOperation,
|
|
typename CElementwiseOperation>
|
|
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
|
|
const Tensor<BType>& b_k_n,
|
|
Tensor<CType>& c_m_n,
|
|
const AElementwiseOperation& a_element_op,
|
|
const BElementwiseOperation& b_element_op,
|
|
const CElementwiseOperation& c_element_op)
|
|
{
|
|
auto f_mk_kn_mn = [&](auto m, auto n) {
|
|
const int K = a_m_k.mDesc.GetLengths()[1];
|
|
|
|
float v_acc = 0;
|
|
|
|
for(int k = 0; k < K; ++k)
|
|
{
|
|
float v_a;
|
|
float v_b;
|
|
|
|
a_element_op(v_a, static_cast<const float>(a_m_k(m, k)));
|
|
b_element_op(v_b, static_cast<const float>(b_k_n(k, n)));
|
|
|
|
v_acc += v_a * v_b;
|
|
}
|
|
|
|
float v_c;
|
|
|
|
c_element_op(v_c, v_acc);
|
|
|
|
c_m_n(m, n) = v_c;
|
|
};
|
|
|
|
make_ParallelTensorFunctor(f_mk_kn_mn,
|
|
c_m_n.mDesc.GetLengths()[0],
|
|
c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
|
|
}
|