Tweak GEMM kernel (#38)

* add parameters

* tweak gemm

* tweak

* update conv

* update script

* adding bwd 1x1

* update script

* adding 1x1 bwd

* debugging bwd 1x1 failure

* update script

* update script

* test

* test v100

* clean up

[ROCm/composable_kernel commit: b3e8d57d51]
This commit is contained in:
Chao Liu
2021-10-06 11:12:36 -05:00
committed by GitHub
parent 8159394bfa
commit 720cf3d6b2
28 changed files with 3642 additions and 529 deletions

View File

@@ -2,6 +2,8 @@
#define DEVICE_HPP
#include <memory>
#include <thread>
#include <chrono>
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
@@ -74,6 +76,8 @@ float launch_and_time_kernel(
timer.End();
// std::this_thread::sleep_for (std::chrono::microseconds(10));
return timer.GetElapsedTime() / nrepeat;
}

View File

@@ -7,6 +7,10 @@ enum GemmMatrixLayout
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
MK_KN_NM, // 4
MK_NK_NM, // 5
KM_KN_NM, // 6
KM_NK_NM, // 7
};
#endif

View File

@@ -80,6 +80,78 @@ void host_gemm(const Tensor<AType>& a,
make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_KN_NM)
{
auto f_mk_kn_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(k, n));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_NK_NM)
{
auto f_mk_nk_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(n, k));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_KN_NM)
{
auto f_km_kn_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(k, n));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_NK_NM)
{
auto f_km_nk_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(n, k));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else
{
throw std::runtime_error("wrong! not supported layout");