Files
composable_kernel/tutorial/ck_tile/gemm/reference_gemm.hpp
AviralGoelAMD adb8f67b4f feat: add new optimized tutorial kernels
- Add 01_naive_gemm baseline implementation
- Add 02_padding_k_first with PADDING_K_FIRST + MFMA_32x32x16
- Add 03_mfma_16x16x16 with PADDING_K_FIRST + MFMA_16x16x16
- Share common reference_gemm.hpp in parent gemm/ directory
2026-01-29 12:45:18 +00:00

38 lines
1.2 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
void reference_basic_gemm(const ck_tile::HostTensor<ADataType>& a_m_k,
const ck_tile::HostTensor<BDataType>& b_n_k,
ck_tile::HostTensor<CDataType>& c_m_n)
{
const int N = b_n_k.mDesc.get_lengths()[0];
const int K = b_n_k.mDesc.get_lengths()[1];
auto f = [&](auto m) {
for(int n = 0; n < N; ++n)
{
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a = a_m_k(m, k);
BDataType v_b = b_n_k(n, k);
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b);
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_acc);
}
};
ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(
std::thread::hardware_concurrency());
}