diff --git a/profiler/include/profile_batched_gemm_impl.hpp b/profiler/include/profile_batched_gemm_impl.hpp index 51fcba910f..7abbf7a042 100644 --- a/profiler/include/profile_batched_gemm_impl.hpp +++ b/profiler/include/profile_batched_gemm_impl.hpp @@ -70,7 +70,7 @@ bool profile_batched_gemm_impl(int do_verification, int StrideA, int StrideB, int StrideC, - int BatchCount = 1) + int BatchCount) { bool pass = true; diff --git a/profiler/src/profile_batched_gemm.cpp b/profiler/src/profile_batched_gemm.cpp index 3021559897..2a806b0818 100644 --- a/profiler/src/profile_batched_gemm.cpp +++ b/profiler/src/profile_batched_gemm.cpp @@ -128,7 +128,8 @@ int profile_batched_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + BatchCount); } else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) { @@ -147,7 +148,8 @@ int profile_batched_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + BatchCount); } else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) { @@ -206,7 +208,8 @@ int profile_batched_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + BatchCount); } else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) { @@ -225,7 +228,8 @@ int profile_batched_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + BatchCount); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) { @@ -284,7 +288,8 @@ int profile_batched_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + BatchCount); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) { @@ -303,7 +308,8 @@ int profile_batched_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + BatchCount); } else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) { @@ -362,7 +368,8 @@ int profile_batched_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + BatchCount); } else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) {