diff --git a/test/gemm_universal/test_gemm_universal_xdl.cpp b/test/gemm_universal/test_gemm_universal_xdl.cpp index 648574839f..2fb9f17f2d 100644 --- a/test/gemm_universal/test_gemm_universal_xdl.cpp +++ b/test/gemm_universal/test_gemm_universal_xdl.cpp @@ -41,16 +41,24 @@ class TestGemmUniversal_MK_NK }; // clang-format off -using KernelTypes = ::testing::Types< +using KernelTypes_MK_KN = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType std::tuple< F16, F16, F16, F16>, std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, + std::tuple< F8, F16, F16, F16>, std::tuple< BF16, BF16, BF16, BF16> >; +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< F16, F16, F16, F16>, + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + std::tuple< BF16, BF16, BF16, BF16>, + std::tuple< F8, F8, F8, BF16> + >; // clang-format on -TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes); -TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes); +TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes_MK_NK); #include "test_gemm_universal_ut_cases.inc"