mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
Added bf16 instances grouped gemm fixed nk (#1825)
* Feat: Add bf16 input instances
* feat: Add BF16 profiler code
* fix: reorder enum types
* fix: CI fail due to clang-format
* fix: clang script format issue
* fix: clang format broke cmakelist file
[ROCm/composable_kernel commit: e7dce4d247]
This commit is contained in:
@@ -17,10 +17,11 @@ enum struct GemmMatrixLayout
|
||||
|
||||
enum struct GemmDataType
|
||||
{
|
||||
BF16_I8_BF16, // 0
|
||||
F16_F16_F16, // 1
|
||||
F16_F8_F16, // 2
|
||||
F16_I8_F16, // 3
|
||||
BF16_I8_BF16, // 0
|
||||
F16_F16_F16, // 1
|
||||
F16_F8_F16, // 2
|
||||
F16_I8_F16, // 3
|
||||
BF16_BF16_BF16 // 4
|
||||
};
|
||||
|
||||
#define OP_NAME "grouped_gemm_fixed_nk"
|
||||
@@ -182,7 +183,7 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
|
||||
n_warmup,
|
||||
n_iter);
|
||||
}
|
||||
#endif
|
||||
#endif // CK_ENABLE_FP8
|
||||
#if defined(CK_ENABLE_INT8)
|
||||
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
@@ -226,12 +227,58 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
1,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
}
|
||||
#endif
|
||||
#endif // CK_ENABLE_INT8
|
||||
#if defined(CK_ENABLE_BF16)
|
||||
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
}
|
||||
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
ck::bhalf_t,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
}
|
||||
#if defined(CK_ENABLE_INT8)
|
||||
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
@@ -279,8 +326,8 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
|
||||
n_warmup,
|
||||
n_iter);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
#endif // CK_ENABLE_INT8
|
||||
#endif // CK_ENABLE_BF16
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
|
||||
|
||||
Reference in New Issue
Block a user