mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
add split-k GEMM (#59)
* add DeviceGemmSplitKXdl
* add file device_gemm_splitk_xdl.hpp
* set c matrix zero
* using atomic
* add all tuning parameter to f32 mkkn
* grid size change to 720
* add tunning parameter for NT
* add tunning parameter for TN
* add tunning parameter for TT
* add m=96tunning parameter
* add lost config
* add element wise operation
* fixed MPerBlock=96
* remove marco for slpitk swtich
* add test
* add new line at the end of device_gemm_xdl_instance.hpp
* remove step hack
* seperate split-k instance files
* add tunning parameters
* change disired grid size to parameters
* remove slice length
* add desiredgridsize parameter to ckProfiler
* add losting file device_gemm_xdl_splitk_instance.hpp
* change desired gride size to kbatch
* format
* format
* clean up
* add selection of device_instances
* clean code
* fix build issue
Co-authored-by: ltqin <letaoqin@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: Jing Zhang <jizhan@amd.com>
[ROCm/composable_kernel commit: 4be7f0198e]
This commit is contained in:
@@ -35,19 +35,20 @@ enum GemmDataType
|
||||
|
||||
int profile_gemm(int argc, char* argv[])
|
||||
{
|
||||
if(argc != 14)
|
||||
if(!(argc == 14 || argc == 15))
|
||||
{
|
||||
printf("arg1: tensor operation (gemm: GEMM)\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16)\n");
|
||||
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, n] * B[k, n] = C[m, n];\n");
|
||||
printf(" 3: A[k, n] * B[n, k] = C[m, n])\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
|
||||
printf("arg4: verification (0: no; 1: yes)\n");
|
||||
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
|
||||
printf("arg8: print tensor value (0: no; 1: yes)\n");
|
||||
printf("arg7: run kernel # of times (>1)\n");
|
||||
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
|
||||
printf("arg14: split k into mulitiple batch\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -65,6 +66,9 @@ int profile_gemm(int argc, char* argv[])
|
||||
const int StrideA = std::stoi(argv[11]);
|
||||
const int StrideB = std::stoi(argv[12]);
|
||||
const int StrideC = std::stoi(argv[13]);
|
||||
int KBatch = 1;
|
||||
if(argc == 15)
|
||||
KBatch = std::stoi(argv[14]);
|
||||
|
||||
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
@@ -159,7 +163,8 @@ int profile_gemm(int argc, char* argv[])
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
KBatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
@@ -178,7 +183,8 @@ int profile_gemm(int argc, char* argv[])
|
||||
K,
|
||||
(StrideA < 0) ? K : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
KBatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
@@ -197,7 +203,8 @@ int profile_gemm(int argc, char* argv[])
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? N : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
KBatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
@@ -216,7 +223,8 @@ int profile_gemm(int argc, char* argv[])
|
||||
K,
|
||||
(StrideA < 0) ? M : StrideA,
|
||||
(StrideB < 0) ? K : StrideB,
|
||||
(StrideC < 0) ? N : StrideC);
|
||||
(StrideC < 0) ? N : StrideC,
|
||||
KBatch);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user