Gemm+Reduce Fusion (#128)

* add gridwise gemm v4r1

* rename

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* use sfc in shuffling

* remove hardcode

* remove hardcode

* refactor

* fix build

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* format

* clean

* adding gemm+reduce

* adding profiler for gemm+reduce

* adding gemm+reduce profiler

* fix build

* clean up

* gemm+reduce

* fix build

* update DeviceGemm_Xdl_CShuffle; update enum to enum class

* clean up

* add test for gemm+reduce

* clean up

* refactor

* fix build

* fix build
This commit is contained in:
Chao Liu
2022-03-23 22:18:42 -05:00
committed by GitHub
parent f91579aab6
commit f95267f166
56 changed files with 4429 additions and 297 deletions

View File

@@ -5,17 +5,18 @@
#include <cstring>
int profile_gemm(int, char*[]);
int profile_batched_gemm(int, char*[]);
int profile_gemm_bias_2d(int, char*[]);
int profile_gemm_bias_relu(int, char*[]);
int profile_gemm_bias_relu_add(int, char*[]);
int profile_gemm_reduce(int, char*[]);
int profile_batched_gemm(int, char*[]);
int profile_grouped_gemm(int, char*[]);
int profile_conv_fwd(int, char*[]);
int profile_conv_fwd_bias_relu(int, char*[]);
int profile_conv_fwd_bias_relu_add(int, char*[]);
int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
int profile_conv_bwd_data(int, char*[]);
int profile_reduce(int, char*[]);
int profile_grouped_gemm(int, char*[]);
int main(int argc, char* argv[])
{
@@ -35,10 +36,18 @@ int main(int argc, char* argv[])
{
return profile_gemm_bias_relu_add(argc, argv);
}
else if(strcmp(argv[1], "gemm_reduce") == 0)
{
return profile_gemm_reduce(argc, argv);
}
else if(strcmp(argv[1], "batched_gemm") == 0)
{
return profile_batched_gemm(argc, argv);
}
else if(strcmp(argv[1], "grouped_gemm") == 0)
{
profile_grouped_gemm(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd") == 0)
{
return profile_conv_fwd(argc, argv);
@@ -63,10 +72,6 @@ int main(int argc, char* argv[])
{
return profile_reduce(argc, argv);
}
else if(strcmp(argv[1], "grouped_gemm") == 0)
{
return profile_grouped_gemm(argc, argv);
}
else
{
// clang-format off
@@ -74,13 +79,14 @@ int main(int argc, char* argv[])
" gemm_bias_2d: GEMM+Bias(2D)\n"
" gemm_bias_relu: GEMM+Bias+ReLU\n"
" gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n"
" gemm_reduce: GEMM+Reduce\n"
" grouped_gemm: Grouped Gemm\n"
" conv_fwd: ForwardConvolution\n"
" conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
" conv_fwd_bias_relu_atomic_add: ForwardConvolution+Bias+ReLU+AtomicAdd\n"
" conv_bwd: BackwardConvolution\n"
" grouped_gemm: Grouped Gemm\n"
" reduce: REDUCE\n");
" reduce: Reduce\n");
// clang-format on
return 0;