Files
composable_kernel/host/driver_offline/src/gemm_driver_offline.cpp
zjing14 a651ea4f7a Fixed bfp16 host_conv_fwd (#52)
* fixed bfloat16 issues

* refactor type_convert

* fixed host_convolution_forward for ushort

Co-authored-by: Chao Liu <chao.liu2@amd.com>
2021-11-18 08:10:56 -06:00

455 lines
14 KiB
C++

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdlops_mk_kn_mn.hpp"
#include "device_gemm_xdlops_mk_nk_mn.hpp"
#include "device_gemm_xdlops_km_kn_mn.hpp"
#include "device_gemm_xdlops_km_nk_mn.hpp"
#include "device_gemm_xdlops_mk_kn_nm.hpp"
#include "device_gemm_xdlops_mk_nk_nm.hpp"
#include "device_gemm_xdlops_km_kn_nm.hpp"
#include "device_gemm_xdlops_km_nk_nm.hpp"
#define USE_GEMM_XDL_MK_KN_MN 1
#define USE_GEMM_XDL_MK_NK_MN 1
#define USE_GEMM_XDL_KM_KN_MN 1
#define USE_GEMM_XDL_KM_NK_MN 1
#define USE_GEMM_XDL_MK_KN_NM 0
#define USE_GEMM_XDL_MK_NK_NM 0
#define USE_GEMM_XDL_KM_KN_NM 0
#define USE_GEMM_XDL_KM_NK_NM 0
enum GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
MK_KN_NM, // 4
MK_NK_NM, // 5
KM_KN_NM, // 6
KM_NK_NM // 7
};
enum GemmAlgo
{
Xdl_MK_KN_MN, // 0
Xdl_MK_NK_MN, // 1
Xdl_KM_KN_MN, // 2
Xdl_KM_NK_MN, // 3
Xdl_MK_KN_NM, // 4
Xdl_MK_NK_NM, // 5
Xdl_KM_KN_NM, // 6
Xdl_KM_NK_NM, // 7
};
template <typename AType, typename BType, typename CType>
void host_gemm(const Tensor<AType>& a,
const Tensor<BType>& b,
Tensor<CType>& c,
const GemmMatrixLayout layout)
{
if(layout == GemmMatrixLayout::MK_KN_MN)
{
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(k, n));
}
c(m, n) = v;
};
make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_NK_MN)
{
auto f_mk_nk_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(n, k));
}
c(m, n) = v;
};
make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_KN_MN)
{
auto f_km_kn_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(k, n));
}
c(m, n) = v;
};
make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_NK_MN)
{
auto f_km_nk_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(n, k));
}
c(m, n) = v;
};
make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_KN_NM)
{
auto f_mk_kn_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(k, n));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_NK_NM)
{
auto f_mk_nk_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(m, k)) * static_cast<const double>(b(n, k));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_KN_NM)
{
auto f_km_kn_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(k, n));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_NK_NM)
{
auto f_km_nk_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a(k, m)) * static_cast<const double>(b(n, k));
}
c(n, m) = v;
};
make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else
{
throw std::runtime_error("wrong! not supported layout");
}
}
int main(int argc, char* argv[])
{
using namespace ck;
if(argc != 12)
{
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
printf("rest: M, N, K\n");
printf("debug_driver_gemm_xdlops_v2r3::M01, debug_driver_gemm_xdlops_v2r3::N01\n");
exit(1);
}
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
const auto algo = static_cast<GemmAlgo>(std::stoi(argv[2]));
const bool do_verification = std::stoi(argv[3]);
const int init_method = std::stoi(argv[4]);
const bool do_log = std::stoi(argv[5]);
const int nrepeat = std::stoi(argv[6]);
const index_t M = std::stoi(argv[7]);
const index_t N = std::stoi(argv[8]);
const index_t K = std::stoi(argv[9]);
debug::debug_driver_gemm_xdlops_v2r3::M01 = std::stoi(argv[10]);
debug::debug_driver_gemm_xdlops_v2r3::N01 = std::stoi(argv[11]);
#if 0
using ab_data_t = float;
using acc_data_t = float;
using c_data_t = float;
#elif 1
using ab_data_t = half_t;
using acc_data_t = float;
using c_data_t = half_t;
#elif 1
using ab_data_t = int8_t;
using acc_data_t = int32_t;
using c_data_t = int8_t;
#endif
std::vector<std::size_t> a_lengths_host(2), b_lengths_host(2), c_lengths_host(2);
std::vector<std::size_t> a_strides_host(2), b_strides_host(2), c_strides_host(2);
// A
if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::MK_NK_MN ||
layout == GemmMatrixLayout::MK_KN_NM || layout == GemmMatrixLayout::MK_NK_NM)
{
a_lengths_host[0] = static_cast<std::size_t>(M);
a_lengths_host[1] = static_cast<std::size_t>(K);
a_strides_host[0] = static_cast<std::size_t>(K);
a_strides_host[1] = static_cast<std::size_t>(1);
}
else
{
a_lengths_host[0] = static_cast<std::size_t>(K);
a_lengths_host[1] = static_cast<std::size_t>(M);
a_strides_host[0] = static_cast<std::size_t>(M);
a_strides_host[1] = static_cast<std::size_t>(1);
}
// B
if(layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN ||
layout == GemmMatrixLayout::MK_NK_NM || layout == GemmMatrixLayout::KM_NK_NM)
{
b_lengths_host[0] = static_cast<std::size_t>(N);
b_lengths_host[1] = static_cast<std::size_t>(K);
b_strides_host[0] = static_cast<std::size_t>(K);
b_strides_host[1] = static_cast<std::size_t>(1);
}
else
{
b_lengths_host[0] = static_cast<std::size_t>(K);
b_lengths_host[1] = static_cast<std::size_t>(N);
b_strides_host[0] = static_cast<std::size_t>(N);
b_strides_host[1] = static_cast<std::size_t>(1);
}
// C
if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::KM_KN_MN ||
layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN)
{
c_lengths_host[0] = static_cast<std::size_t>(M);
c_lengths_host[1] = static_cast<std::size_t>(N);
c_strides_host[0] = static_cast<std::size_t>(N);
c_strides_host[1] = static_cast<std::size_t>(1);
}
else
{
c_lengths_host[0] = static_cast<std::size_t>(N);
c_lengths_host[1] = static_cast<std::size_t>(M);
c_strides_host[0] = static_cast<std::size_t>(M);
c_strides_host[1] = static_cast<std::size_t>(1);
}
Tensor<ab_data_t> a(a_lengths_host, a_strides_host);
Tensor<ab_data_t> b(b_lengths_host, b_strides_host);
Tensor<c_data_t> c_host(c_lengths_host, c_strides_host);
Tensor<c_data_t> c_device(c_lengths_host, c_strides_host);
std::cout << "layout: " << layout << std::endl;
ostream_HostTensorDescriptor(a.mDesc, std::cout << "a: ");
ostream_HostTensorDescriptor(b.mDesc, std::cout << "b: ");
ostream_HostTensorDescriptor(c_host.mDesc, std::cout << "c: ");
std::size_t num_thread = std::thread::hardware_concurrency();
switch(init_method)
{
case 0:
// no initialization
break;
case 1:
a.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
b.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
break;
case 2:
a.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
b.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
break;
case 3:
a.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
b.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
break;
case 4:
a.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
b.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
break;
default:
a.GenerateTensorValue(GeneratorTensor_3<ab_data_t>{0.0, 1.0}, num_thread);
b.GenerateTensorValue(GeneratorTensor_3<ab_data_t>{-0.5, 0.5}, num_thread);
}
#if USE_GEMM_XDL_MK_KN_MN
if(algo == GemmAlgo::Xdl_MK_KN_MN)
{
if(layout != GemmMatrixLayout::MK_KN_MN)
{
throw std::runtime_error("wrong! layout");
}
device_gemm_xdlops_mk_kn_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
#if USE_GEMM_XDL_MK_NK_MN
if(algo == GemmAlgo::Xdl_MK_NK_MN)
{
if(layout != GemmMatrixLayout::MK_NK_MN)
{
throw std::runtime_error("wrong! layout");
}
device_gemm_xdlops_mk_nk_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
#if USE_GEMM_XDL_KM_KN_MN
if(algo == GemmAlgo::Xdl_KM_KN_MN)
{
if(layout != GemmMatrixLayout::KM_KN_MN)
{
throw std::runtime_error("wrong! layout");
}
device_gemm_xdlops_km_kn_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
#if USE_GEMM_XDL_KM_NK_MN
if(algo == GemmAlgo::Xdl_KM_NK_MN)
{
if(layout != GemmMatrixLayout::KM_NK_MN)
{
throw std::runtime_error("wrong! layout");
}
device_gemm_xdlops_km_nk_mn<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
#if USE_GEMM_XDL_MK_KN_NM
if(algo == GemmAlgo::Xdl_MK_KN_NM)
{
if(layout != GemmMatrixLayout::MK_KN_NM)
{
throw std::runtime_error("wrong! layout");
}
device_gemm_xdlops_mk_kn_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
#if USE_GEMM_XDL_MK_NK_NM
if(algo == GemmAlgo::Xdl_MK_NK_NM)
{
if(layout != GemmMatrixLayout::MK_NK_NM)
{
throw std::runtime_error("wrong! layout");
}
device_gemm_xdlops_mk_nk_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
#if USE_GEMM_XDL_KM_KN_NM
if(algo == GemmAlgo::Xdl_KM_KN_NM)
{
if(layout != GemmMatrixLayout::KM_KN_NM)
{
throw std::runtime_error("wrong! layout");
}
device_gemm_xdlops_km_kn_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
#if USE_GEMM_XDL_KM_NK_NM
if(algo == GemmAlgo::Xdl_KM_NK_NM)
{
if(layout != GemmMatrixLayout::KM_NK_NM)
{
throw std::runtime_error("wrong! layout");
}
device_gemm_xdlops_km_nk_nm<ab_data_t, acc_data_t, c_data_t>(a, b, c_device, nrepeat);
}
#endif
if(do_verification)
{
host_gemm(a, b, c_host, layout);
check_error(c_host, c_device);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host : ", c_host.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_device.mData, ",") << std::endl;
}
}
}