diff --git a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp old mode 100755 new mode 100644 index e625fae808..d145ab1766 --- a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp @@ -166,7 +166,7 @@ bool profile_gemm_universal_streamk_impl(int do_verification, 0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K+ DP, // 2:2-tile Stream-K + DP - if(Grid_size == -1) + if(Grid_size != -1) { grid_size_list = {Grid_size}; } diff --git a/test/gemm_universal/CMakeLists.txt b/test/gemm_universal/CMakeLists.txt old mode 100755 new mode 100644 index cf5c68e220..4aab6323cc --- a/test/gemm_universal/CMakeLists.txt +++ b/test/gemm_universal/CMakeLists.txt @@ -1,15 +1,4 @@ -add_gtest_executable(test_gemm_universal_fp16 test_gemm_universal_xdl_fp16.cpp) +add_gtest_executable(test_gemm_universal test_gemm_universal_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_universal_fp16 PRIVATE utility device_gemm_universal_instance) + target_link_libraries(test_gemm_universal PRIVATE utility device_gemm_universal_instance) endif() - -add_gtest_executable(test_gemm_universal_fp8 test_gemm_universal_xdl_fp8.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_universal_fp8 PRIVATE utility device_gemm_universal_instance) -endif() - -add_gtest_executable(test_gemm_universal_bf16 test_gemm_universal_xdl_bf16.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_universal_bf16 PRIVATE utility device_gemm_universal_instance) -endif() - diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc b/test/gemm_universal/test_gemm_universal_ut_cases.inc similarity index 75% rename from test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc rename to test/gemm_universal/test_gemm_universal_ut_cases.inc index 8a6c672a9f..9a21666856 100644 --- a/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc +++ b/test/gemm_universal/test_gemm_universal_ut_cases.inc @@ -1,6 +1,6 @@ #pragma once -TYPED_TEST(TestGemmUniversal_BF16_MK_KN, SmallM) +TYPED_TEST(TestGemmUniversal_MK_KN, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -14,7 +14,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_BF16_MK_NK, SmallM) +TYPED_TEST(TestGemmUniversal_MK_NK, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -28,7 +28,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_BF16_KM_KN, SmallM) +TYPED_TEST(TestGemmUniversal_KM_KN, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -44,7 +44,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_KN, SmallM) } } -TYPED_TEST(TestGemmUniversal_BF16_KM_NK, SmallM) +TYPED_TEST(TestGemmUniversal_KM_NK, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -60,7 +60,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_NK, SmallM) } } -TYPED_TEST(TestGemmUniversal_BF16_MK_KN, MidLargeM) +TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -74,7 +74,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, MidLargeM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_BF16_MK_NK, MidLargeM) +TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -88,7 +88,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, MidLargeM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_BF16_KM_KN, MidLargeM) +TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -104,7 +104,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_KN, MidLargeM) } } -TYPED_TEST(TestGemmUniversal_BF16_KM_NK, MidLargeM) +TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -120,7 +120,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_NK, MidLargeM) } } -TYPED_TEST(TestGemmUniversal_BF16_MK_KN, PaddK) +TYPED_TEST(TestGemmUniversal_MK_KN, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -134,7 +134,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, PaddK) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_BF16_MK_NK, PaddK) +TYPED_TEST(TestGemmUniversal_MK_NK, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -148,7 +148,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, PaddK) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_BF16_KM_KN, PaddK) +TYPED_TEST(TestGemmUniversal_KM_KN, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -164,7 +164,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_KN, PaddK) } } -TYPED_TEST(TestGemmUniversal_BF16_KM_NK, PaddK) +TYPED_TEST(TestGemmUniversal_KM_NK, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -180,7 +180,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_NK, PaddK) } } -TYPED_TEST(TestGemmUniversal_BF16_MK_KN, Regular) +TYPED_TEST(TestGemmUniversal_MK_KN, Regular) { std::vector Ms{512}; constexpr int N = 512; @@ -194,7 +194,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, Regular) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_BF16_MK_NK, Regular) +TYPED_TEST(TestGemmUniversal_MK_NK, Regular) { std::vector Ms{512}; constexpr int N = 512; @@ -207,3 +207,35 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, Regular) for(int M : Ms) this->Run(M, N, K, StrideA, StrideB, StrideC); } + +TYPED_TEST(TestGemmUniversal_KM_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmUniversal_KM_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc deleted file mode 100644 index b61ea0e6b4..0000000000 --- a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc +++ /dev/null @@ -1,99 +0,0 @@ -#pragma once - -TYPED_TEST(TestGemmUniversal_FP16_MK_KN, SmallM) -{ - std::vector Ms{1, 2, 3, 4, 5, 6}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP16_MK_NK, SmallM) -{ - std::vector Ms{1, 2, 3, 4, 5, 6}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP16_MK_NK, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK) -{ - std::vector Ms{127}; - constexpr int N = 512; - constexpr int K = 437; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK) -{ - std::vector Ms{127}; - constexpr int N = 512; - constexpr int K = 437; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP16_MK_KN, Regular) -{ - std::vector Ms{512}; - constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP16_MK_NK, Regular) -{ - std::vector Ms{512}; - constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc deleted file mode 100644 index b831e15e9c..0000000000 --- a/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc +++ /dev/null @@ -1,113 +0,0 @@ -#pragma once - -TYPED_TEST(TestGemmUniversal_FP8_MK_KN, SmallM) -{ - std::vector Ms{1, 2, 3, 4, 5, 6}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP8_MK_NK, SmallM) -{ - std::vector Ms{1, 2, 3, 4, 5, 6}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP8_MK_KN, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP8_MK_NK, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP8_MK_KN, PaddK) -{ - std::vector Ms{127}; - constexpr int N = 512; - constexpr int K = 437; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP8_MK_NK, PaddK) -{ - std::vector Ms{127}; - constexpr int N = 512; - constexpr int K = 437; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP8_MK_KN, Regular) -{ - std::vector Ms{512}; - constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP8_MK_NK, Regular) -{ - std::vector Ms{512}; - constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} diff --git a/test/gemm_universal/test_gemm_universal_xdl_bf16.cpp b/test/gemm_universal/test_gemm_universal_xdl.cpp similarity index 61% rename from test/gemm_universal/test_gemm_universal_xdl_bf16.cpp rename to test/gemm_universal/test_gemm_universal_xdl.cpp index 8fde65657a..b872d7089a 100644 --- a/test/gemm_universal/test_gemm_universal_xdl_bf16.cpp +++ b/test/gemm_universal/test_gemm_universal_xdl.cpp @@ -7,6 +7,8 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "test_gemm_universal_util.hpp" +using F8 = ck::f8_t; +using F16 = ck::half_t; using BF16 = ck::bhalf_t; using F32 = float; @@ -27,25 +29,25 @@ struct tuple_concat, std::tuple> } // namespace template -class TestGemmUniversal_BF16_MK_KN +class TestGemmUniversal_MK_KN : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_BF16_MK_NK +class TestGemmUniversal_MK_NK : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_BF16_KM_KN +class TestGemmUniversal_KM_KN : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_BF16_KM_NK +class TestGemmUniversal_KM_NK : public ck::test::TestGemmUniversal, Tuple>::type> { }; @@ -53,12 +55,22 @@ class TestGemmUniversal_BF16_KM_NK // clang-format off using KernelTypes_MK_KN = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType - + std::tuple< F16, F16, F16, F16>, +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + std::tuple< F8, F8, F8, BF16>, +#endif std::tuple< BF16, BF16, BF16, BF16> >; using KernelTypes_MK_NK = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType - + std::tuple< F16, F16, F16, F16>, +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + std::tuple< F8, F8, F8, BF16>, +#endif std::tuple< BF16, BF16, BF16, BF16> >; @@ -74,9 +86,9 @@ using KernelTypes_KM_KN = ::testing::Types< // clang-format on -TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN); -TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK); -TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN); -TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK); +TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes_MK_NK); +TYPED_TEST_SUITE(TestGemmUniversal_KM_KN, KernelTypes_KM_KN); +TYPED_TEST_SUITE(TestGemmUniversal_KM_NK, KernelTypes_KM_NK); -#include "test_gemm_universal_ut_cases_bf16.inc" +#include "test_gemm_universal_ut_cases.inc" diff --git a/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp b/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp deleted file mode 100644 index 24f587daf6..0000000000 --- a/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp +++ /dev/null @@ -1,82 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "gtest/gtest.h" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "test_gemm_universal_util.hpp" - -using F8 = ck::f8_t; -using F16 = ck::half_t; - -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -namespace { - -template -struct tuple_concat; - -template -struct tuple_concat, std::tuple> -{ - using type = std::tuple; -}; - -} // namespace - -template -class TestGemmUniversal_FP16_MK_KN - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -template -class TestGemmUniversal_FP16_MK_NK - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -template -class TestGemmUniversal_FP16_KM_KN - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -template -class TestGemmUniversal_FP16_KM_NK - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -// clang-format off -using KernelTypes_MK_KN = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - -#endif - std::tuple< F16, F16, F16, F16> - >; -using KernelTypes_MK_NK = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - -#endif - std::tuple< F16, F16, F16, F16> - >; - -// clang-format on - -TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN); -TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK); - -#include "test_gemm_universal_ut_cases_fp16.inc" diff --git a/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp b/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp deleted file mode 100644 index e833ab7825..0000000000 --- a/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp +++ /dev/null @@ -1,71 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "gtest/gtest.h" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "test_gemm_universal_util.hpp" - -using F8 = ck::f8_t; -using F16 = ck::half_t; -using BF16 = ck::bhalf_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -namespace { - -template -struct tuple_concat; - -template -struct tuple_concat, std::tuple> -{ - using type = std::tuple; -}; - -} // namespace - -template -class TestGemmUniversal_FP8_MK_KN - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -template -class TestGemmUniversal_FP8_MK_NK - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -// clang-format off -using KernelTypes_MK_KN = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - std::tuple< F8, F8, F8, BF16>, -#endif - // Fallback test type when FP8 is not enabled - std::tuple< F16, F16, F16, F16> - >; -using KernelTypes_MK_NK = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - std::tuple< F8, F8, F8, BF16>, -#endif - // Fallback test type when FP8 is not enabled - std::tuple< F16, F16, F16, F16> - >; - - -TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN); -TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK); - - -#include "test_gemm_universal_ut_cases_fp8.inc" diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp b/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp index 805587a274..ef3509c0ca 100644 --- a/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp @@ -44,8 +44,9 @@ class TestGemmUniversal_Streamk : public testing::Test void SetUp() override { - streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile - // Stream-K+ DP, // {0, 1, 2, 3, 4} + grid_size_list = {38, 114, 228}; // {38, 76, 114, 152, 190, 228, 266, 304, 342, 380}; + streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile + // Stream-K+ DP, // {0, 1, 2, 3, 4} // 2:2-tile Stream-K + DP } @@ -57,9 +58,10 @@ class TestGemmUniversal_Streamk : public testing::Test const int StrideC) { for(auto streamk_sel : streamk_sel_list) - { - RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, -1); - } + for(auto grid_size : grid_size_list) + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, grid_size); + } } void RunSingle(const int M,