mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
Add fp8 gemm instances (#920)
* Add fp8 gemm instances * Update instance naming
This commit is contained in:
@@ -23,6 +23,7 @@ enum struct GemmDataType
|
||||
F16_F16_F16, // 1
|
||||
BF16_BF16_BF16, // 2
|
||||
INT8_INT8_INT8, // 3
|
||||
F8_F8_F8, // 4
|
||||
};
|
||||
|
||||
#define OP_NAME "gemm"
|
||||
@@ -31,7 +32,7 @@ enum struct GemmDataType
|
||||
static void print_helper_msg()
|
||||
{
|
||||
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
|
||||
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
|
||||
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: fp8)\n"
|
||||
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
|
||||
<< " 1: A[m, k] * B[n, k] = C[m, n];\n"
|
||||
<< " 2: A[k, m] * B[k, n] = C[m, n];\n"
|
||||
@@ -76,6 +77,9 @@ int profile_gemm(int argc, char* argv[])
|
||||
using INT8 = int8_t;
|
||||
using INT32 = int32_t;
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP8
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -194,6 +198,24 @@ int profile_gemm(int argc, char* argv[])
|
||||
{
|
||||
return profile(Col{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP8
|
||||
else if(data_type == GemmDataType::F8_F8_F8 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(Row{}, Row{}, Row{}, F8{}, F8{}, F32{}, F8{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F8_F8_F8 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(Row{}, Col{}, Row{}, F8{}, F8{}, F32{}, F8{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F8_F8_F8 && layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
return profile(Col{}, Row{}, Row{}, F8{}, F8{}, F32{}, F8{});
|
||||
}
|
||||
else if(data_type == GemmDataType::F8_F8_F8 && layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
return profile(Col{}, Col{}, Row{}, F8{}, F8{}, F32{}, F8{});
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user