From 539412dd89f5f0ef92d83cdad5ee85bf65789f08 Mon Sep 17 00:00:00 2001 From: Jianfeng Yan Date: Thu, 7 Apr 2022 13:17:15 -0500 Subject: [PATCH] Fix typo in batched gemm profiler (#176) * forgot passing BatchedCount in some profiler_batched_gemm * delete default BatchCount [ROCm/composable_kernel commit: ac0d806650280b770bde1dac952535b34a2d4f5d] --- .../include/profile_batched_gemm_impl.hpp | 2 +- profiler/src/profile_batched_gemm.cpp | 21 ++++++++++++------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/profiler/include/profile_batched_gemm_impl.hpp b/profiler/include/profile_batched_gemm_impl.hpp index 51fcba910f..7abbf7a042 100644 --- a/profiler/include/profile_batched_gemm_impl.hpp +++ b/profiler/include/profile_batched_gemm_impl.hpp @@ -70,7 +70,7 @@ bool profile_batched_gemm_impl(int do_verification, int StrideA, int StrideB, int StrideC, - int BatchCount = 1) + int BatchCount) { bool pass = true; diff --git a/profiler/src/profile_batched_gemm.cpp b/profiler/src/profile_batched_gemm.cpp index 3021559897..2a806b0818 100644 --- a/profiler/src/profile_batched_gemm.cpp +++ b/profiler/src/profile_batched_gemm.cpp @@ -128,7 +128,8 @@ int profile_batched_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + BatchCount); } else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) { @@ -147,7 +148,8 @@ int profile_batched_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + BatchCount); } else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) { @@ -206,7 +208,8 @@ int profile_batched_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + BatchCount); } else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) { @@ -225,7 +228,8 @@ int profile_batched_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + BatchCount); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) { @@ -284,7 +288,8 @@ int profile_batched_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + BatchCount); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) { @@ -303,7 +308,8 @@ int profile_batched_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + BatchCount); } else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) { @@ -362,7 +368,8 @@ int profile_batched_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + BatchCount); } else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) {