mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Grouped Gemm + SplitK + simplified Kernel Args (#669)
* simplify karg in device/grid split-k op
* fix mk_kn_mn instances
* add more instances
* B2C with 3D grid for KSplit
* Remove unused code.
* Use default B2C (3D grid) in grid gemm v2r4r2.
* Device gemm splitk use B2C map.
* Device GroupedGemmXdlSplitKCShuffle
* Example for GroupedGemm Xdl SplitK
* Introduce Device GroupedGemmSplitK
* Fix updating kbatch size.
* Add instance mk-nk-mn
* Enable set kbatch in profiler.
* Add GGemmSplitK mk-kn-mn instances
* Add more instances & split into multiple files.
* minor fix
* tuning
* clean
* disabled failed instances
* use pipe v2
* Ignore arg on not supported arch.
* fix warning
---------
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Jing Zhang <jizhan@amd.com>
Co-authored-by: root <root@ctr-ubbsmc15.amd.com>
[ROCm/composable_kernel commit: 8bb2bb4a05]
This commit is contained in:
@@ -52,20 +52,24 @@ std::vector<int> argToIntArray(char* input)
|
||||
|
||||
int profile_grouped_gemm(int argc, char* argv[])
|
||||
{
|
||||
if(!(argc == 14))
|
||||
if(argc < 14)
|
||||
{
|
||||
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
|
||||
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
|
||||
printf("arg4: verification (0: no; 1: yes)\n");
|
||||
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
|
||||
printf("arg6: print tensor value (0: no; 1: yes)\n");
|
||||
printf("arg7: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
|
||||
"64,64 64,64 128,128)\n");
|
||||
std::cout
|
||||
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
|
||||
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
|
||||
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
|
||||
<< " 1: A[m, k] * B[n, k] = C[m, n];\n"
|
||||
<< " 2: A[k, m] * B[k, n] = C[m, n];\n"
|
||||
<< " 3: A[k, m] * B[n, k] = C[m, n])\n"
|
||||
<< "arg4: verification (0: no; 1: yes)\n"
|
||||
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
|
||||
<< "arg6: print tensor value (0: no; 1: yes)\n"
|
||||
<< "arg7: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
|
||||
"64,64 64,64 128,128)\n"
|
||||
<< "arg15: kbatch value (default 4)\n"
|
||||
<< std::endl;
|
||||
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -83,6 +87,7 @@ int profile_grouped_gemm(int argc, char* argv[])
|
||||
const auto StrideAs = argToIntArray(argv[11]);
|
||||
const auto StrideBs = argToIntArray(argv[12]);
|
||||
const auto StrideCs = argToIntArray(argv[13]);
|
||||
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
|
||||
|
||||
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
@@ -101,7 +106,8 @@ int profile_grouped_gemm(int argc, char* argv[])
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs);
|
||||
StrideCs,
|
||||
kbatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
@@ -120,7 +126,8 @@ int profile_grouped_gemm(int argc, char* argv[])
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs);
|
||||
StrideCs,
|
||||
kbatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
@@ -139,7 +146,8 @@ int profile_grouped_gemm(int argc, char* argv[])
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs);
|
||||
StrideCs,
|
||||
kbatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
@@ -158,7 +166,8 @@ int profile_grouped_gemm(int argc, char* argv[])
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs);
|
||||
StrideCs,
|
||||
kbatch);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user