Fix typo in batched gemm profiler (#176)

* forgot passing BatchedCount in some profiler_batched_gemm

* delete default BatchCount

[ROCm/composable_kernel commit: ac0d806650]
This commit is contained in:
Jianfeng Yan
2022-04-07 13:17:15 -05:00
committed by GitHub
parent 23e9f358bb
commit 539412dd89
2 changed files with 15 additions and 8 deletions

View File

@@ -70,7 +70,7 @@ bool profile_batched_gemm_impl(int do_verification,
int StrideA,
int StrideB,
int StrideC,
int BatchCount = 1)
int BatchCount)
{
bool pass = true;

View File

@@ -128,7 +128,8 @@ int profile_batched_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC,
BatchCount);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
@@ -147,7 +148,8 @@ int profile_batched_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC,
BatchCount);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
@@ -206,7 +208,8 @@ int profile_batched_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC,
BatchCount);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
{
@@ -225,7 +228,8 @@ int profile_batched_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC,
BatchCount);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
@@ -284,7 +288,8 @@ int profile_batched_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC,
BatchCount);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
@@ -303,7 +308,8 @@ int profile_batched_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC,
BatchCount);
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
{
@@ -362,7 +368,8 @@ int profile_batched_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC,
BatchCount);
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
{