mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
Gemm+Reduce Fusion (#128)
* add gridwise gemm v4r1
* rename
* adding gemm+reduce
* adding gemm+reduce
* adding gemm+reduce
* adding gemm+reduce
* use sfc in shuffling
* remove hardcode
* remove hardcode
* refactor
* fix build
* adding gemm+reduce
* adding gemm+reduce
* adding gemm+reduce
* adding gemm+reduce
* adding gemm+reduce
* format
* clean
* adding gemm+reduce
* adding profiler for gemm+reduce
* adding gemm+reduce profiler
* fix build
* clean up
* gemm+reduce
* fix build
* update DeviceGemm_Xdl_CShuffle; update enum to enum class
* clean up
* add test for gemm+reduce
* clean up
* refactor
* fix build
* fix build
This commit is contained in:
@@ -5,17 +5,18 @@
|
||||
#include <cstring>
|
||||
|
||||
int profile_gemm(int, char*[]);
|
||||
int profile_batched_gemm(int, char*[]);
|
||||
int profile_gemm_bias_2d(int, char*[]);
|
||||
int profile_gemm_bias_relu(int, char*[]);
|
||||
int profile_gemm_bias_relu_add(int, char*[]);
|
||||
int profile_gemm_reduce(int, char*[]);
|
||||
int profile_batched_gemm(int, char*[]);
|
||||
int profile_grouped_gemm(int, char*[]);
|
||||
int profile_conv_fwd(int, char*[]);
|
||||
int profile_conv_fwd_bias_relu(int, char*[]);
|
||||
int profile_conv_fwd_bias_relu_add(int, char*[]);
|
||||
int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
|
||||
int profile_conv_bwd_data(int, char*[]);
|
||||
int profile_reduce(int, char*[]);
|
||||
int profile_grouped_gemm(int, char*[]);
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@@ -35,10 +36,18 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return profile_gemm_bias_relu_add(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "gemm_reduce") == 0)
|
||||
{
|
||||
return profile_gemm_reduce(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "batched_gemm") == 0)
|
||||
{
|
||||
return profile_batched_gemm(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "grouped_gemm") == 0)
|
||||
{
|
||||
profile_grouped_gemm(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "conv_fwd") == 0)
|
||||
{
|
||||
return profile_conv_fwd(argc, argv);
|
||||
@@ -63,10 +72,6 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return profile_reduce(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "grouped_gemm") == 0)
|
||||
{
|
||||
return profile_grouped_gemm(argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
// clang-format off
|
||||
@@ -74,13 +79,14 @@ int main(int argc, char* argv[])
|
||||
" gemm_bias_2d: GEMM+Bias(2D)\n"
|
||||
" gemm_bias_relu: GEMM+Bias+ReLU\n"
|
||||
" gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n"
|
||||
" gemm_reduce: GEMM+Reduce\n"
|
||||
" grouped_gemm: Grouped Gemm\n"
|
||||
" conv_fwd: ForwardConvolution\n"
|
||||
" conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
|
||||
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
|
||||
" conv_fwd_bias_relu_atomic_add: ForwardConvolution+Bias+ReLU+AtomicAdd\n"
|
||||
" conv_bwd: BackwardConvolution\n"
|
||||
" grouped_gemm: Grouped Gemm\n"
|
||||
" reduce: REDUCE\n");
|
||||
" reduce: Reduce\n");
|
||||
// clang-format on
|
||||
|
||||
return 0;
|
||||
|
||||
Reference in New Issue
Block a user