From 96f752aba9e040c483a9e320f9481cef050e7306 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:19:33 -0700 Subject: [PATCH] Fix gemm_splitk test, add hip_check_error after kernel calls in kernel_launch. (#951) * Added error check after kernel launch (#919) Co-authored-by: Xiaodong Wang Co-authored-by: Xiaodong Wang * remove M=0 test cases for test_gemm_splitk --------- Co-authored-by: Xiaodong Wang Co-authored-by: Xiaodong Wang [ROCm/composable_kernel commit: bc1108bb3ee4e18c0417609c215e1779b37f8d39] --- include/ck/host_utility/kernel_launch.hpp | 8 ++++++++ test/gemm_split_k/test_gemm_splitk_ut_cases.inc | 8 ++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index 3d27103dcb..df311620e1 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -34,6 +34,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, #endif // warm up kernel<<>>(args...); + hip_check_error(hipGetLastError()); const int nrepeat = 10; #if DEBUG_LOG @@ -50,6 +51,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, for(int i = 0; i < nrepeat; ++i) { kernel<<>>(args...); + hip_check_error(hipGetLastError()); } hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); @@ -64,11 +66,13 @@ float launch_and_time_kernel(const StreamConfig& stream_config, else { kernel<<>>(args...); + hip_check_error(hipGetLastError()); return 0; } #else kernel<<>>(args...); + hip_check_error(hipGetLastError()); return 0; #endif @@ -101,6 +105,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, // warm up preprocess(); kernel<<>>(args...); + hip_check_error(hipGetLastError()); const int nrepeat = 10; #if DEBUG_LOG @@ -118,6 +123,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, { preprocess(); kernel<<>>(args...); + hip_check_error(hipGetLastError()); } hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); @@ -133,11 +139,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, { preprocess(); kernel<<>>(args...); + hip_check_error(hipGetLastError()); return 0; } #else kernel<<>>(args...); + hip_check_error(hipGetLastError()); return 0; #endif diff --git a/test/gemm_split_k/test_gemm_splitk_ut_cases.inc b/test/gemm_split_k/test_gemm_splitk_ut_cases.inc index 54b9c6c9e3..d583a3e377 100644 --- a/test/gemm_split_k/test_gemm_splitk_ut_cases.inc +++ b/test/gemm_split_k/test_gemm_splitk_ut_cases.inc @@ -2,7 +2,7 @@ TYPED_TEST(TestGemmSplitK_MK_KN, SmallM) { - std::vector Ms{0, 1, 2, 3, 4, 5, 6}; + std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; constexpr int K = 320; @@ -16,7 +16,7 @@ TYPED_TEST(TestGemmSplitK_MK_KN, SmallM) TYPED_TEST(TestGemmSplitK_MK_NK, SmallM) { - std::vector Ms{0, 1, 2, 3, 4, 5, 6}; + std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; constexpr int K = 320; @@ -30,7 +30,7 @@ TYPED_TEST(TestGemmSplitK_MK_NK, SmallM) TYPED_TEST(TestGemmSplitK_KM_KN, SmallM) { - std::vector Ms{0, 1, 2, 3, 4, 5, 6}; + std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; constexpr int K = 320; @@ -43,7 +43,7 @@ TYPED_TEST(TestGemmSplitK_KM_KN, SmallM) TYPED_TEST(TestGemmSplitK_KM_NK, SmallM) { - std::vector Ms{0, 1, 2, 3, 4, 5, 6}; + std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; constexpr int K = 320;