mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
Add instances/ckProfiler/client example for fp8/fp16 mixed precision Gemm (#853)
* Add ComputeType arg to splitk device and gridwise ops
* Update for gridwise op compatibility
* Update bf16 and int8 splitk gemm examples with ComputeType
* Add instances
* Update ckProfiler for mixed precision cases
* Add a mixed precision splitK gemm client example
---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>
[ROCm/composable_kernel commit: eac50708d9]
This commit is contained in:
@@ -23,6 +23,8 @@ enum struct GemmDataType
|
||||
F16_F16_F16, // 1
|
||||
BF16_BF16_BF16, // 2
|
||||
INT8_INT8_INT8, // 3
|
||||
F8_F16_F16, // 4
|
||||
F16_F8_F16, // 5
|
||||
};
|
||||
|
||||
#define OP_NAME "gemm_splitk"
|
||||
@@ -33,7 +35,7 @@ int profile_gemm_splitk(int argc, char* argv[])
|
||||
if(argc != 15)
|
||||
{
|
||||
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8)\n");
|
||||
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
@@ -65,6 +67,7 @@ int profile_gemm_splitk(int argc, char* argv[])
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using F8 = ck::f8_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -143,6 +146,38 @@ int profile_gemm_splitk(int argc, char* argv[])
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user