Add Gemm instances for performance improvement (#1018)

* improve kpad

* more tuning parameters

* f16_f8_fp16

* cut test time

* add f16_f8_fp16

* add f16_f8_f16

* testing instances for skinny cases

* format

* clean

* add fp16_f8_fp16

* clang-format

* add grouped gemm instalces

* fixed profile grouped_gemm

* clean

* clean

* clean

* clean

* clean

* add missing instance func

* fixed inferface

---------

Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: root <root@sh5-1e707-rc06-38.mkm.dcgpu>

[ROCm/composable_kernel commit: 98fd41f597]
This commit is contained in:
zjing14
2023-11-07 09:09:58 -06:00
committed by GitHub
parent 8b3a547a76
commit be753a8db1
28 changed files with 1346 additions and 338 deletions

View File

@@ -27,6 +27,8 @@ enum struct GemmDataType
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
F8_F16_F16, // 4
F16_F8_F16, // 5
};
#define OP_NAME "grouped_gemm"
@@ -56,7 +58,7 @@ int profile_grouped_gemm(int argc, char* argv[])
{
std::cout
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: fp8@fp6; 5: f16@f8)\n"
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
<< " 1: A[m, k] * B[n, k] = C[m, n];\n"
<< " 2: A[k, m] * B[k, n] = C[m, n];\n"
@@ -169,6 +171,46 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideCs,
kbatch);
}
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::f8_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch);
}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::f8_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch);
}
else
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");