mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Add Gemm instances for performance improvement (#1018)
* improve kpad
* more tuning parameters
* f16_f8_fp16
* cut test time
* add f16_f8_fp16
* add f16_f8_f16
* testing instances for skinny cases
* format
* clean
* add fp16_f8_fp16
* clang-format
* add grouped gemm instalces
* fixed profile grouped_gemm
* clean
* clean
* clean
* clean
* clean
* add missing instance func
* fixed inferface
---------
Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: root <root@sh5-1e707-rc06-38.mkm.dcgpu>
[ROCm/composable_kernel commit: 98fd41f597]
This commit is contained in:
@@ -27,6 +27,8 @@ enum struct GemmDataType
|
||||
F16_F16_F16, // 1
|
||||
BF16_BF16_BF16, // 2
|
||||
INT8_INT8_INT8, // 3
|
||||
F8_F16_F16, // 4
|
||||
F16_F8_F16, // 5
|
||||
};
|
||||
|
||||
#define OP_NAME "grouped_gemm"
|
||||
@@ -56,7 +58,7 @@ int profile_grouped_gemm(int argc, char* argv[])
|
||||
{
|
||||
std::cout
|
||||
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
|
||||
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
|
||||
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: fp8@fp6; 5: f16@f8)\n"
|
||||
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
|
||||
<< " 1: A[m, k] * B[n, k] = C[m, n];\n"
|
||||
<< " 2: A[k, m] * B[k, n] = C[m, n];\n"
|
||||
@@ -169,6 +171,46 @@ int profile_grouped_gemm(int argc, char* argv[])
|
||||
StrideCs,
|
||||
kbatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_grouped_gemm_impl<ck::f8_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
kbatch);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
|
||||
ck::f8_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
kbatch);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
|
||||
|
||||
Reference in New Issue
Block a user