Gemm+Reduce Fusion (#128)

* add gridwise gemm v4r1

* rename

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* use sfc in shuffling

* remove hardcode

* remove hardcode

* refactor

* fix build

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* format

* clean

* adding gemm+reduce

* adding profiler for gemm+reduce

* adding gemm+reduce profiler

* fix build

* clean up

* gemm+reduce

* fix build

* update DeviceGemm_Xdl_CShuffle; update enum to enum class

* clean up

* add test for gemm+reduce

* clean up

* refactor

* fix build

* fix build
This commit is contained in:
Chao Liu
2022-03-23 22:18:42 -05:00
committed by GitHub
parent f91579aab6
commit f95267f166
56 changed files with 4429 additions and 297 deletions

View File

@@ -6,7 +6,7 @@
#include <half.hpp>
#include "profile_gemm_impl.hpp"
enum GemmMatrixLayout
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM, // 7
};
enum GemmDataType
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
@@ -45,8 +45,8 @@ int profile_gemm(int argc, char* argv[])
exit(1);
}
const int data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);