mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
Add splitk gemm fp16 @ fp16 with fp8 compute instances (#983)
* Add ComputeType
* Update for compatibility
* Add instances
* Update profiler API
This commit is contained in:
@@ -25,6 +25,7 @@ enum struct GemmDataType
|
||||
INT8_INT8_INT8, // 3
|
||||
F8_F16_F16, // 4
|
||||
F16_F8_F16, // 5
|
||||
F16_F16_F16_F8, // 6
|
||||
};
|
||||
|
||||
#define OP_NAME "gemm_splitk"
|
||||
@@ -35,7 +36,8 @@ int profile_gemm_splitk(int argc, char* argv[])
|
||||
if(argc != 15)
|
||||
{
|
||||
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8)\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, "
|
||||
"comp f8)\n");
|
||||
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
@@ -80,7 +82,8 @@ int profile_gemm_splitk(int argc, char* argv[])
|
||||
auto c_type,
|
||||
auto a_layout,
|
||||
auto b_layout,
|
||||
auto c_layout) {
|
||||
auto c_layout,
|
||||
auto compute_type) {
|
||||
using ADataType = decltype(a_type);
|
||||
using BDataType = decltype(b_type);
|
||||
using AccDataType = decltype(acc_type);
|
||||
@@ -90,6 +93,8 @@ int profile_gemm_splitk(int argc, char* argv[])
|
||||
using BLayout = decltype(b_layout);
|
||||
using CLayout = decltype(c_layout);
|
||||
|
||||
using ComputeType = decltype(compute_type);
|
||||
|
||||
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
|
||||
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
|
||||
@@ -100,7 +105,8 @@ int profile_gemm_splitk(int argc, char* argv[])
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
CLayout,
|
||||
ComputeType>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
@@ -118,68 +124,84 @@ int profile_gemm_splitk(int argc, char* argv[])
|
||||
|
||||
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
|
||||
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}, F32{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
|
||||
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}, F32{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
|
||||
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}, F32{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
|
||||
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}, F32{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
|
||||
}
|
||||
#if defined CK_ENABLE_FP8
|
||||
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
|
||||
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
|
||||
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
|
||||
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Row{}, Row{});
|
||||
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Row{}, Row{});
|
||||
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{});
|
||||
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F8{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F8{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F8{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F8{});
|
||||
}
|
||||
#endif
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user