mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
- Add 01_naive_gemm baseline implementation
- Add 02_padding_k_first with PADDING_K_FIRST + MFMA_32x32x16
- Add 03_mfma_16x16x16 with PADDING_K_FIRST + MFMA_16x16x16
- Share common reference_gemm.hpp in parent gemm/ directory
38 lines
1.2 KiB
C++
38 lines
1.2 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include <cstddef>
#include <stdexcept>
#include <thread>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
|
|
|
|
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
|
void reference_basic_gemm(const ck_tile::HostTensor<ADataType>& a_m_k,
|
|
const ck_tile::HostTensor<BDataType>& b_n_k,
|
|
ck_tile::HostTensor<CDataType>& c_m_n)
|
|
{
|
|
const int N = b_n_k.mDesc.get_lengths()[0];
|
|
const int K = b_n_k.mDesc.get_lengths()[1];
|
|
|
|
auto f = [&](auto m) {
|
|
for(int n = 0; n < N; ++n)
|
|
{
|
|
AccDataType v_acc = 0;
|
|
|
|
for(int k = 0; k < K; ++k)
|
|
{
|
|
ADataType v_a = a_m_k(m, k);
|
|
BDataType v_b = b_n_k(n, k);
|
|
|
|
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
|
|
ck_tile::type_convert<AccDataType>(v_b);
|
|
}
|
|
|
|
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_acc);
|
|
}
|
|
};
|
|
|
|
ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(
|
|
std::thread::hardware_concurrency());
|
|
}
|