From daabe29bff610ce9006b4b7d00653a97e6cbf1b0 Mon Sep 17 00:00:00 2001 From: emezh Date: Fri, 26 Sep 2025 22:55:18 -0400 Subject: [PATCH] fix copy-paste bug in get_matrix_b; re-enable all tests in multi_abd (#2939) [ROCm/composable_kernel commit: 2aa06fbd4509b43334a96d36f96948cc4d2e3c0b] --- .../profiler/profile_gemm_multi_abd_impl.hpp | 2 +- .../test_gemm_multi_abd_wmma.cpp | 85 +++++++++---------- .../test_gemm_multi_abd_xdl.cpp | 85 +++++++++---------- 3 files changed, 83 insertions(+), 89 deletions(-) diff --git a/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp index a3c5c6a3ac..46745fd02b 100644 --- a/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp +++ b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp @@ -224,7 +224,7 @@ bool profile_gemm_multi_abd_impl(int do_verification, auto get_b_matrix = [&]() -> auto { // in case of pass through we avoid allocating a new // tensor and copying values - if constexpr(is_same_v) + if constexpr(is_same_v) { return bs_k_n(Number<0>{}); } diff --git a/test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp b/test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp index a15f95bbf8..42584ecc02 100644 --- a/test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp +++ b/test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp @@ -26,9 +26,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; -using KernelTypesABD = ::testing::Types< -#if 0 // TBD: skip temporary because they fail HostTensdorDescriptor validation - std::tuple, +using KernelTypesABD = ::testing::Types, ck::Tuple, ck::Tuple, ck::Tuple, @@ -108,47 +106,46 @@ using KernelTypesABD = ::testing::Types< PassThrough, Multiply, PassThrough>, -#endif - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyAddFastGelu>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyAdd>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyFastGelu>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - Multiply>>; + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAdd>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>; TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD); TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); } diff --git a/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp b/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp index a15f95bbf8..42584ecc02 100644 --- a/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp +++ b/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp @@ -26,9 +26,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; -using KernelTypesABD = ::testing::Types< -#if 0 // TBD: skip temporary because they fail HostTensdorDescriptor validation - std::tuple, +using KernelTypesABD = ::testing::Types, ck::Tuple, ck::Tuple, ck::Tuple, @@ -108,47 +106,46 @@ using KernelTypesABD = ::testing::Types< PassThrough, Multiply, PassThrough>, -#endif - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyAddFastGelu>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyAdd>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyFastGelu>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - Multiply>>; + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAdd>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>; TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD); TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); }