Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-19 12:30:16 +00:00)
Tweak GEMM kernel (#38)
* add parameters
* tweak gemm
* tweak
* update conv
* update script
* adding bwd 1x1
* update script
* adding 1x1 bwd
* debugging bwd 1x1 failure
* update script
* update script
* test
* test v100
* clean up
[ROCm/composable_kernel commit: b3e8d57d51]
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
#define DEVICE_HPP
|
||||
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
|
||||
@@ -74,6 +76,8 @@ float launch_and_time_kernel(
|
||||
|
||||
timer.End();
|
||||
|
||||
// std::this_thread::sleep_for (std::chrono::microseconds(10));
|
||||
|
||||
return timer.GetElapsedTime() / nrepeat;
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,10 @@ enum GemmMatrixLayout
|
||||
MK_NK_MN, // 1
|
||||
KM_KN_MN, // 2
|
||||
KM_NK_MN, // 3
|
||||
MK_KN_NM, // 4
|
||||
MK_NK_NM, // 5
|
||||
KM_KN_NM, // 6
|
||||
KM_NK_NM, // 7
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -80,6 +80,78 @@ void host_gemm(const Tensor<AType>& a,
|
||||
make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_KN_NM)
|
||||
{
|
||||
auto f_mk_kn_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(k, n));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_NK_NM)
|
||||
{
|
||||
auto f_mk_nk_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(n, k));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_KN_NM)
|
||||
{
|
||||
auto f_km_kn_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(k, n));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_NK_NM)
|
||||
{
|
||||
auto f_km_nk_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(n, k));
|
||||
}
|
||||
|
||||
c(n, m) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
|
||||
Reference in New Issue
Block a user