mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
fix: preprocessor directives logic error if/else (#1764)
* fix: preprocessors logic error if/else * fix: added macros as preferred by CK team
This commit is contained in:
@@ -21,7 +21,6 @@ enum struct GemmDataType
|
||||
F16_F16_F16, // 1
|
||||
F16_F8_F16, // 2
|
||||
F16_I8_F16, // 3
|
||||
|
||||
};
|
||||
|
||||
#define OP_NAME "grouped_gemm_fixed_nk"
|
||||
@@ -39,7 +38,6 @@ std::vector<int> argToIntArray(char* input)
|
||||
{
|
||||
out.push_back(std::stoi(item));
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@@ -83,14 +81,6 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
|
||||
const auto StrideCs = argToIntArray(argv[13]);
|
||||
const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1;
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
#if defined(CK_ENABLE_FP8)
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
|
||||
int n_warmup = 1;
|
||||
int n_iter = 10;
|
||||
if(argc == 17)
|
||||
@@ -99,61 +89,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
|
||||
n_iter = std::stoi(argv[16]);
|
||||
}
|
||||
|
||||
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
|
||||
if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16,
|
||||
I8,
|
||||
BF16,
|
||||
F32,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
}
|
||||
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16,
|
||||
I8,
|
||||
BF16,
|
||||
F32,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
}
|
||||
#endif
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
|
||||
F16,
|
||||
F16,
|
||||
F32,
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
@@ -173,10 +114,10 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
|
||||
F16,
|
||||
F16,
|
||||
F32,
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
@@ -194,14 +135,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
|
||||
n_warmup,
|
||||
n_iter);
|
||||
}
|
||||
#endif
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)
|
||||
#if defined(CK_ENABLE_FP8)
|
||||
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
|
||||
F8,
|
||||
F16,
|
||||
F32,
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
|
||||
ck::f8_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
@@ -221,10 +161,10 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
|
||||
F8,
|
||||
F16,
|
||||
F32,
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
|
||||
ck::f8_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
@@ -243,13 +183,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
|
||||
n_iter);
|
||||
}
|
||||
#endif
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8)
|
||||
#if defined(CK_ENABLE_INT8)
|
||||
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
|
||||
I8,
|
||||
F16,
|
||||
F32,
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
|
||||
int8_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
@@ -269,10 +209,10 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
|
||||
I8,
|
||||
F16,
|
||||
F32,
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
|
||||
int8_t,
|
||||
ck::half_t,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
@@ -290,6 +230,56 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
|
||||
n_warmup,
|
||||
n_iter);
|
||||
}
|
||||
#endif
|
||||
#if defined(CK_ENABLE_BF16)
|
||||
#if defined(CK_ENABLE_INT8)
|
||||
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
|
||||
int8_t,
|
||||
ck::bhalf_t,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
}
|
||||
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
|
||||
int8_t,
|
||||
ck::bhalf_t,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user