Gemm+Reduce Fusion (#128)

* add gridwise gemm v4r1

* rename

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* use sfc in shuffling

* remove hardcode

* remove hardcode

* refactor

* fix build

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* format

* clean

* adding gemm+reduce

* adding profiler for gemm+reduce

* adding gemm+reduce profiler

* fix build

* clean up

* gemm+reduce

* fix build

* update DeviceGemm_Xdl_CShuffle; update enum to enum class

* clean up

* add test for gemm+reduce

* clean up

* refactor

* fix build

* fix build

[ROCm/composable_kernel commit: f95267f166]
This commit is contained in:
Chao Liu
2022-03-23 22:18:42 -05:00
committed by GitHub
parent 2f3f406393
commit 8cba08d07a
56 changed files with 4429 additions and 297 deletions

View File

@@ -16,7 +16,7 @@
#include "device_batched_gemm_xdl.hpp"
#include "profile_batched_gemm_impl.hpp"
enum GemmMatrixLayout
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
@@ -28,7 +28,7 @@ enum GemmMatrixLayout
KM_NK_NM, // 7
};
enum GemmDataType
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
@@ -54,8 +54,8 @@ int profile_batched_gemm(int argc, char* argv[])
exit(1);
}
const int data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);