Added bf16 instances for grouped gemm fixed nk (#1825)

* Feat: Add bf16 input instances

* feat: Add BF16 profiler code

* fix: reorder enum types

* fix: CI failure due to clang-format

* fix: clang script format issue

* fix: clang-format broke CMakeLists.txt file

[ROCm/composable_kernel commit: e7dce4d247]
This commit is contained in:
deepsek
2025-01-20 12:13:09 -05:00
committed by GitHub
parent c54cff82f0
commit dde428cdf9
5 changed files with 256 additions and 11 deletions

View File

@@ -17,10 +17,11 @@ enum struct GemmMatrixLayout
enum struct GemmDataType
{
BF16_I8_BF16, // 0
F16_F16_F16, // 1
F16_F8_F16, // 2
F16_I8_F16, // 3
BF16_I8_BF16, // 0
F16_F16_F16, // 1
F16_F8_F16, // 2
F16_I8_F16, // 3
BF16_BF16_BF16 // 4
};
#define OP_NAME "grouped_gemm_fixed_nk"
@@ -182,7 +183,7 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup,
n_iter);
}
#endif
#endif // CK_ENABLE_FP8
#if defined(CK_ENABLE_INT8)
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
@@ -226,12 +227,58 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
1,
kbatch,
n_warmup,
n_iter);
}
#endif
#endif // CK_ENABLE_INT8
#if defined(CK_ENABLE_BF16)
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
#if defined(CK_ENABLE_INT8)
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
@@ -279,8 +326,8 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup,
n_iter);
}
#endif
#endif
#endif // CK_ENABLE_INT8
#endif // CK_ENABLE_BF16
else
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");