// Mirror of https://github.com/ROCm/composable_kernel.git
// Synced 2026-05-12 17:26:00 +00:00
// (Upstream change summary: StaticBufferV2 init; blockwise GEMM xdlops
//  adaptors; bf16 (gfx908/gfx90a) and int8 support; bwd 1x1 work; cleanups.
//  Co-authored-by: Chao Liu <chao.liu2@amd.com>)
#pragma once

#include <stdexcept>
#include <thread>

#include "host_tensor.hpp"
template <>
|
|
void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
|
|
const Tensor<ushort>& b,
|
|
Tensor<ushort>& c,
|
|
const GemmMatrixLayout layout)
|
|
{
|
|
if(layout == GemmMatrixLayout::MK_KN_MN)
|
|
{
|
|
auto f_mk_kn_mn = [&](auto m, auto n) {
|
|
const int K = a.mDesc.GetLengths()[1];
|
|
|
|
double v = 0;
|
|
|
|
for(int k = 0; k < K; ++k)
|
|
{
|
|
v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(k, n));
|
|
}
|
|
|
|
c(m, n) = float_to_bfloat16(v);
|
|
};
|
|
|
|
make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
|
std::thread::hardware_concurrency());
|
|
}
|
|
else if(layout == GemmMatrixLayout::MK_NK_MN)
|
|
{
|
|
auto f_mk_nk_mn = [&](auto m, auto n) {
|
|
const int K = a.mDesc.GetLengths()[1];
|
|
|
|
double v = 0;
|
|
|
|
for(int k = 0; k < K; ++k)
|
|
{
|
|
v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(n, k));
|
|
}
|
|
|
|
c(m, n) = float_to_bfloat16(v);
|
|
};
|
|
|
|
make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
|
std::thread::hardware_concurrency());
|
|
}
|
|
else if(layout == GemmMatrixLayout::KM_KN_MN)
|
|
{
|
|
auto f_km_kn_mn = [&](auto m, auto n) {
|
|
const int K = a.mDesc.GetLengths()[0];
|
|
|
|
double v = 0;
|
|
|
|
for(int k = 0; k < K; ++k)
|
|
{
|
|
v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(k, n));
|
|
}
|
|
|
|
c(m, n) = float_to_bfloat16(v);
|
|
};
|
|
|
|
make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
|
std::thread::hardware_concurrency());
|
|
}
|
|
else if(layout == GemmMatrixLayout::KM_NK_MN)
|
|
{
|
|
auto f_km_nk_mn = [&](auto m, auto n) {
|
|
const int K = a.mDesc.GetLengths()[0];
|
|
|
|
double v = 0;
|
|
|
|
for(int k = 0; k < K; ++k)
|
|
{
|
|
v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(n, k));
|
|
}
|
|
|
|
c(m, n) = float_to_bfloat16(v);
|
|
};
|
|
|
|
make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
|
std::thread::hardware_concurrency());
|
|
}
|
|
else if(layout == GemmMatrixLayout::MK_KN_NM)
|
|
{
|
|
auto f_mk_kn_nm = [&](auto n, auto m) {
|
|
const int K = a.mDesc.GetLengths()[1];
|
|
|
|
double v = 0;
|
|
|
|
for(int k = 0; k < K; ++k)
|
|
{
|
|
v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(k, n));
|
|
}
|
|
|
|
c(n, m) = float_to_bfloat16(v);
|
|
};
|
|
|
|
make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
|
std::thread::hardware_concurrency());
|
|
}
|
|
else if(layout == GemmMatrixLayout::MK_NK_NM)
|
|
{
|
|
auto f_mk_nk_nm = [&](auto n, auto m) {
|
|
const int K = a.mDesc.GetLengths()[1];
|
|
|
|
double v = 0;
|
|
|
|
for(int k = 0; k < K; ++k)
|
|
{
|
|
v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(n, k));
|
|
}
|
|
|
|
c(n, m) = float_to_bfloat16(v);
|
|
};
|
|
|
|
make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
|
std::thread::hardware_concurrency());
|
|
}
|
|
else if(layout == GemmMatrixLayout::KM_KN_NM)
|
|
{
|
|
auto f_km_kn_nm = [&](auto n, auto m) {
|
|
const int K = a.mDesc.GetLengths()[0];
|
|
|
|
double v = 0;
|
|
|
|
for(int k = 0; k < K; ++k)
|
|
{
|
|
v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(k, n));
|
|
}
|
|
|
|
c(n, m) = float_to_bfloat16(v);
|
|
};
|
|
|
|
make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
|
std::thread::hardware_concurrency());
|
|
}
|
|
else if(layout == GemmMatrixLayout::KM_NK_NM)
|
|
{
|
|
auto f_km_nk_nm = [&](auto n, auto m) {
|
|
const int K = a.mDesc.GetLengths()[0];
|
|
|
|
double v = 0;
|
|
|
|
for(int k = 0; k < K; ++k)
|
|
{
|
|
v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(n, k));
|
|
}
|
|
|
|
c(n, m) = float_to_bfloat16(v);
|
|
};
|
|
|
|
make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
|
std::thread::hardware_concurrency());
|
|
}
|
|
else
|
|
{
|
|
throw std::runtime_error("wrong! not supported layout");
|
|
}
|
|
}
|
|
|
|
template <typename AType, typename BType, typename CType>
|
|
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
|
|
const Tensor<BType>& b_k_n,
|
|
Tensor<CType>& c_m_n)
|
|
{
|
|
auto f_mk_kn_mn = [&](auto m, auto n) {
|
|
const int K = a_m_k.mDesc.GetLengths()[1];
|
|
|
|
double v = 0;
|
|
|
|
for(int k = 0; k < K; ++k)
|
|
{
|
|
v += static_cast<const double>(a_m_k(m, k)) * static_cast<const double>(b_k_n(k, n));
|
|
}
|
|
|
|
c_m_n(m, n) = v;
|
|
};
|
|
|
|
make_ParallelTensorFunctor(f_mk_kn_mn,
|
|
c_m_n.mDesc.GetLengths()[0],
|
|
c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
|
|
}
|