From 0b4a5b6b994380bebc6a957096871ea8d2af8255 Mon Sep 17 00:00:00 2001 From: Muhammed Emin Ozturk Date: Thu, 3 Apr 2025 14:22:43 -0700 Subject: [PATCH] CkProfiler StreamK GemmUniversal Fix and Split Gemm_universal Test (#2044) * fix and split gemm_universal test * clang * Update test_gemm_universal_ut_cases_bf16.inc * Update test_gemm_universal_xdl_bf16.cpp * Update test_gemm_universal_ut_cases_fp16.inc [ROCm/composable_kernel commit: 7142d8003c6a99f952a62bbd0b90d5f0261fc807] --- .../profile_gemm_universal_streamk_impl.hpp | 2 +- test/gemm_universal/CMakeLists.txt | 15 ++- ... => test_gemm_universal_ut_cases_bf16.inc} | 60 +++------- .../test_gemm_universal_ut_cases_fp16.inc | 99 +++++++++++++++ .../test_gemm_universal_ut_cases_fp8.inc | 113 ++++++++++++++++++ ...l.cpp => test_gemm_universal_xdl_bf16.cpp} | 34 ++---- .../test_gemm_universal_xdl_fp16.cpp | 82 +++++++++++++ .../test_gemm_universal_xdl_fp8.cpp | 71 +++++++++++ .../test_gemm_universal_streamk_util.hpp | 12 +- 9 files changed, 409 insertions(+), 79 deletions(-) mode change 100644 => 100755 profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp mode change 100644 => 100755 test/gemm_universal/CMakeLists.txt rename test/gemm_universal/{test_gemm_universal_ut_cases.inc => test_gemm_universal_ut_cases_bf16.inc} (75%) create mode 100644 test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc create mode 100644 test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc rename test/gemm_universal/{test_gemm_universal_xdl.cpp => test_gemm_universal_xdl_bf16.cpp} (61%) create mode 100644 test/gemm_universal/test_gemm_universal_xdl_fp16.cpp create mode 100644 test/gemm_universal/test_gemm_universal_xdl_fp8.cpp diff --git a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp old mode 100644 new mode 100755 index d145ab1766..e625fae808 --- a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp @@ -166,7 +166,7 @@ bool profile_gemm_universal_streamk_impl(int do_verification, 0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K+ DP, // 2:2-tile Stream-K + DP - if(Grid_size != -1) + if(Grid_size == -1) { grid_size_list = {Grid_size}; } diff --git a/test/gemm_universal/CMakeLists.txt b/test/gemm_universal/CMakeLists.txt old mode 100644 new mode 100755 index 4aab6323cc..cf5c68e220 --- a/test/gemm_universal/CMakeLists.txt +++ b/test/gemm_universal/CMakeLists.txt @@ -1,4 +1,15 @@ -add_gtest_executable(test_gemm_universal test_gemm_universal_xdl.cpp) +add_gtest_executable(test_gemm_universal_fp16 test_gemm_universal_xdl_fp16.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_universal PRIVATE utility device_gemm_universal_instance) + target_link_libraries(test_gemm_universal_fp16 PRIVATE utility device_gemm_universal_instance) endif() + +add_gtest_executable(test_gemm_universal_fp8 test_gemm_universal_xdl_fp8.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_fp8 PRIVATE utility device_gemm_universal_instance) +endif() + +add_gtest_executable(test_gemm_universal_bf16 test_gemm_universal_xdl_bf16.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_bf16 PRIVATE utility device_gemm_universal_instance) +endif() + diff --git a/test/gemm_universal/test_gemm_universal_ut_cases.inc b/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc similarity index 75% rename from test/gemm_universal/test_gemm_universal_ut_cases.inc rename to test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc index 9a21666856..8a6c672a9f 100644 --- a/test/gemm_universal/test_gemm_universal_ut_cases.inc +++ b/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc @@ -1,6 +1,6 @@ #pragma once -TYPED_TEST(TestGemmUniversal_MK_KN, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -14,7 +14,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -28,7 +28,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_KM_KN, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_KM_KN, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -44,7 +44,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, SmallM) } } -TYPED_TEST(TestGemmUniversal_KM_NK, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_KM_NK, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -60,7 +60,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, SmallM) } } -TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -74,7 +74,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -88,7 +88,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM) +TYPED_TEST(TestGemmUniversal_BF16_KM_KN, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -104,7 +104,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM) } } -TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM) +TYPED_TEST(TestGemmUniversal_BF16_KM_NK, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -120,7 +120,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM) } } -TYPED_TEST(TestGemmUniversal_MK_KN, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -134,7 +134,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, PaddK) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -148,7 +148,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, PaddK) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_KM_KN, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_KM_KN, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -164,7 +164,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, PaddK) } } -TYPED_TEST(TestGemmUniversal_KM_NK, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_KM_NK, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -180,7 +180,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, PaddK) } } -TYPED_TEST(TestGemmUniversal_MK_KN, Regular) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, Regular) { std::vector Ms{512}; constexpr int N = 512; @@ -194,7 +194,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, Regular) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, Regular) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, Regular) { std::vector Ms{512}; constexpr int N = 512; @@ -207,35 +207,3 @@ TYPED_TEST(TestGemmUniversal_MK_NK, Regular) for(int M : Ms) this->Run(M, N, K, StrideA, StrideB, StrideC); } - -TYPED_TEST(TestGemmUniversal_KM_KN, Regular) -{ - std::vector Ms{512}; - constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - { - int StrideA = M; - this->Run(M, N, K, StrideA, StrideB, StrideC); - } -} - -TYPED_TEST(TestGemmUniversal_KM_NK, Regular) -{ - std::vector Ms{512}; - constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - { - int StrideA = M; - this->Run(M, N, K, StrideA, StrideB, StrideC); - } -} diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc new file mode 100644 index 0000000000..b61ea0e6b4 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc @@ -0,0 +1,99 @@ +#pragma once + +TYPED_TEST(TestGemmUniversal_FP16_MK_KN, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc new file mode 100644 index 0000000000..b831e15e9c --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc @@ -0,0 +1,113 @@ +#pragma once + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} diff --git a/test/gemm_universal/test_gemm_universal_xdl.cpp b/test/gemm_universal/test_gemm_universal_xdl_bf16.cpp similarity index 61% rename from test/gemm_universal/test_gemm_universal_xdl.cpp rename to test/gemm_universal/test_gemm_universal_xdl_bf16.cpp index b872d7089a..8fde65657a 100644 --- a/test/gemm_universal/test_gemm_universal_xdl.cpp +++ b/test/gemm_universal/test_gemm_universal_xdl_bf16.cpp @@ -7,8 +7,6 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "test_gemm_universal_util.hpp" -using F8 = ck::f8_t; -using F16 = ck::half_t; using BF16 = ck::bhalf_t; using F32 = float; @@ -29,25 +27,25 @@ struct tuple_concat, std::tuple> } // namespace template -class TestGemmUniversal_MK_KN +class TestGemmUniversal_BF16_MK_KN : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_MK_NK +class TestGemmUniversal_BF16_MK_NK : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_KM_KN +class TestGemmUniversal_BF16_KM_KN : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_KM_NK +class TestGemmUniversal_BF16_KM_NK : public ck::test::TestGemmUniversal, Tuple>::type> { }; @@ -55,22 +53,12 @@ class TestGemmUniversal_KM_NK // clang-format off using KernelTypes_MK_KN = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< F16, F16, F16, F16>, -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - std::tuple< F8, F8, F8, BF16>, -#endif + std::tuple< BF16, BF16, BF16, BF16> >; using KernelTypes_MK_NK = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< F16, F16, F16, F16>, -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - std::tuple< F8, F8, F8, BF16>, -#endif + std::tuple< BF16, BF16, BF16, BF16> >; @@ -86,9 +74,9 @@ using KernelTypes_KM_KN = ::testing::Types< // clang-format on -TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes_MK_KN); -TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes_MK_NK); -TYPED_TEST_SUITE(TestGemmUniversal_KM_KN, KernelTypes_KM_KN); -TYPED_TEST_SUITE(TestGemmUniversal_KM_NK, KernelTypes_KM_NK); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK); -#include "test_gemm_universal_ut_cases.inc" +#include "test_gemm_universal_ut_cases_bf16.inc" diff --git a/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp b/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp new file mode 100644 index 0000000000..24f587daf6 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_FP16_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP16_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP16_KM_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP16_KM_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + +#endif + std::tuple< F16, F16, F16, F16> + >; +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + +#endif + std::tuple< F16, F16, F16, F16> + >; + +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_universal_ut_cases_fp16.inc" diff --git a/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp b/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp new file mode 100644 index 0000000000..e833ab7825 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_FP8_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP8_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + std::tuple< F8, F8, F8, BF16>, +#endif + // Fallback test type when FP8 is not enabled + std::tuple< F16, F16, F16, F16> + >; +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + std::tuple< F8, F8, F8, BF16>, +#endif + // Fallback test type when FP8 is not enabled + std::tuple< F16, F16, F16, F16> + >; + + +TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK); + + +#include "test_gemm_universal_ut_cases_fp8.inc" diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp b/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp index ef3509c0ca..805587a274 100644 --- a/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp @@ -44,9 +44,8 @@ class TestGemmUniversal_Streamk : public testing::Test void SetUp() override { - grid_size_list = {38, 114, 228}; // {38, 76, 114, 152, 190, 228, 266, 304, 342, 380}; - streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile - // Stream-K+ DP, // {0, 1, 2, 3, 4} + streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile + // Stream-K+ DP, // {0, 1, 2, 3, 4} // 2:2-tile Stream-K + DP } @@ -58,10 +57,9 @@ class TestGemmUniversal_Streamk : public testing::Test const int StrideC) { for(auto streamk_sel : streamk_sel_list) - for(auto grid_size : grid_size_list) - { - RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, grid_size); - } + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, -1); + } } void RunSingle(const int M,