[rocm-libraries] ROCm/rocm-libraries#4415 (commit b3b4af7)

[CK] Remove duplicated XDL/WMMA tests
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

When we started the RDNA4 support, the XDL instances did not support
WMMA instructions, so we duplicated some tests.

In this PR, we consolidated most of the duplicated test files into
common test files.

## Technical Details

The following tests were unified:

- `batched_gemm`

- `batched_gemm_gemm`

- `gemm_add`

- `gemm_universal`

- `grouped_convnd_bwd_data`

The following tests were duplicated exactly, and copied into two files
with `_xdl` and `_wmma` suffixes. Now they are unified in one single
file without suffix:

- `gemm_multi_abd`

- `gemm_b_scale`

There is still an apparent duplication which is a special case, namely
`test_grouped_convnd_bwd_weight_interface_{suffix}` where `{suffix}` is
`xdl` or `wmma`.
However, the WMMA code relies on an old implementation, and is expected
to be removed in the future. In addition, it differs from the XDL
implementation significantly.
Therefore, it was decided to keep both files separate instead of
attempting any unification.

## Test Plan

`CMakeLists.txt` files were modified to support the new, unified tests.
In particular, testing was done for `gfx90a`, `gfx1201` and `gfx11`
architectures.

## Test Result

All tests passed successfully on all three tested architectures.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
JP-Fernando
2026-02-25 23:23:02 +00:00
committed by assistant-librarian[bot]
parent 6bf2423685
commit 9a32f0ea19
31 changed files with 307 additions and 1340 deletions

View File

@@ -1,32 +1,23 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_gtest_executable(test_gemm_universal_wmma_fp16 test_gemm_universal_wmma_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_universal_wmma_fp16 PRIVATE utility device_gemm_universal_instance)
endif()
add_gtest_executable(test_gemm_universal_wmma_bf16 test_gemm_universal_wmma_bf16.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_universal_wmma_bf16 PRIVATE utility device_gemm_universal_instance)
endif()
# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
# as these tests are universal and don't have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link
# the instance library if there are no instances present for the current arch.
if (CK_USE_XDL OR CK_USE_WMMA)
add_gtest_executable(test_gemm_universal_fp16 test_gemm_universal_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_universal_fp16 PRIVATE utility device_gemm_universal_instance)
endif()
add_gtest_executable(test_gemm_universal_wmma_fp8 test_gemm_universal_wmma_fp8.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_universal_wmma_fp8 PRIVATE utility device_gemm_universal_instance)
endif()
add_gtest_executable(test_gemm_universal_fp8 test_gemm_universal_fp8.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_universal_fp8 PRIVATE utility device_gemm_universal_instance)
endif()
add_gtest_executable(test_gemm_universal_xdl_fp16 test_gemm_universal_xdl_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_universal_xdl_fp16 PRIVATE utility device_gemm_universal_instance)
endif()
add_gtest_executable(test_gemm_universal_xdl_fp8 test_gemm_universal_xdl_fp8.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_universal_xdl_fp8 PRIVATE utility device_gemm_universal_instance)
endif()
add_gtest_executable(test_gemm_universal_xdl_bf16 test_gemm_universal_xdl_bf16.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_universal_xdl_bf16 PRIVATE utility device_gemm_universal_instance)
add_gtest_executable(test_gemm_universal_bf16 test_gemm_universal_bf16.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_universal_bf16 PRIVATE utility device_gemm_universal_instance)
endif()
endif()

View File

@@ -55,7 +55,8 @@ class TestGemmUniversal_BF16_KM_NK
// clang-format off
using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< BF16, BF16, BF16, BF16>
std::tuple< BF16, BF16, BF16, BF16>
>;
using KernelTypes_MK_NK = ::testing::Types<
@@ -66,11 +67,6 @@ using KernelTypes_MK_NK = ::testing::Types<
std::tuple< BF16, BF16, BF16, BF16>
>;
using KernelTypes_KM_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< BF16, BF16, BF16, BF16>
>;
using KernelTypes_KM_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8)
@@ -78,6 +74,12 @@ using KernelTypes_KM_NK = ::testing::Types<
#endif
std::tuple< BF16, BF16, BF16, BF16>
>;
using KernelTypes_KM_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< BF16, BF16, BF16, BF16>
>;
// clang-format on
TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN);

View File

@@ -44,31 +44,34 @@ class TestGemmUniversal_FP8_MK_NK
// clang-format off
using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) && !defined(CK_USE_WMMA_FP8)
std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>,
std::tuple< F8, F8, F8, BF16>,
#endif
std::tuple< F8, F16, F16, F16>>;
#elif defined(CK_USE_WMMA_FP8)
// Fallback test type when WMMA FP8 is used
std::tuple< F8, F8, F8, BF16>>;
#else
// Fallback test type when FP8 is not enabled
std::tuple< F16, F16, F16, F16>
>;
std::tuple< F16, F16, F16, F16>>;
#endif
using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) && !defined(CK_USE_WMMA_FP8)
std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>,
std::tuple< F8, F8, F8, BF16>,
#endif
std::tuple< F8, F16, F16, F16>>;
#elif defined(CK_USE_WMMA_FP8)
// Fallback test type when WMMA FP8 is used
std::tuple< F8, F8, F8, BF16>>;
#else
// Fallback test type when FP8 is not enabled
std::tuple< F16, F16, F16, F16>
>;
std::tuple< F16, F16, F16, F16>>;
#endif
// clang-format on
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN);
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK);
#include "test_gemm_universal_ut_cases_fp8.inc"
int main(int argc, char** argv)
{

View File

@@ -1,78 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_universal_util.hpp"
// Test-filtering globals read by the generated test cases (see the .inc file).
// param_mask: presumably a bitmask selecting which problem parameters to run
// (0xffff enables all) — confirm against test_gemm_universal_util.hpp.
ck::index_t param_mask = 0xffff;
// Index of the single device instance to exercise; -1 means run all instances
// (see the usage message in main()).
ck::index_t instance_index = -1;
// The entire test body is compiled only when WMMA FP8 instances exist.
#if defined(CK_USE_WMMA_FP8)
// Shorthand aliases for the element types and tensor layouts used below.
using F8 = ck::f8_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
namespace {
/// Type-level concatenation of two std::tuple specializations:
/// tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>::type is
/// std::tuple<Xs..., Ys...>.
template <typename X, typename Y>
struct tuple_concat
{
    // Let std::tuple_cat compute the concatenated tuple type; decltype
    // recovers it without ever constructing a value.
    using type = decltype(std::tuple_cat(std::declval<X>(), std::declval<Y>()));
};
} // namespace
// Typed-test fixtures: each binds a fixed (A-layout, B-layout) pair and lets
// the Tuple parameter supply the (A, B, Compute, C) data types.
// Row/Row corresponds to the MK_KN naming, Row/Col to MK_NK.
template <typename Tuple>
class TestGemmUniversal_FP8_MK_KN
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_FP8_MK_NK
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
{
};
// clang-format off
// Data-type combinations exercised per layout; columns follow the header row.
using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F8, F8, F8, BF16>
>;
using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F8, F8, F8, BF16>
>;
// clang-format on
// Register the fixtures with their type lists, then pull in the shared
// FP8 test-case definitions.
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN);
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK);
#include "test_gemm_universal_ut_cases_fp8.inc"
#endif // CK_USE_WMMA_FP8
// Entry point: initializes GoogleTest, optionally overrides the file-scope
// filtering globals (param_mask, instance_index) from the command line, then
// runs every registered test.
int main(int argc, char** argv)
{
    testing::InitGoogleTest(&argc, argv);
    // No extra arguments keeps the defaults; exactly two extra arguments set
    // the filters; anything else prints usage (tests still run either way).
    switch(argc)
    {
    case 1: break;
    case 3:
        param_mask     = strtol(argv[1], nullptr, 0);
        instance_index = atoi(argv[2]);
        break;
    default:
        std::cout << "Usage of " << argv[0] << std::endl;
        std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
        break;
    }
    return RUN_ALL_TESTS();
}

View File

@@ -1,99 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_universal_util.hpp"
// Test-filtering globals read by the generated test cases (see the .inc file).
// param_mask: presumably a bitmask selecting which problem parameters to run
// (0xffff enables all) — confirm against test_gemm_universal_util.hpp.
ck::index_t param_mask = 0xffff;
// Index of the single device instance to exercise; -1 means run all instances
// (see the usage message in main()).
ck::index_t instance_index = -1;
// Shorthand aliases for the element types and tensor layouts used below.
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
namespace {
/// Type-level concatenation of two std::tuple specializations:
/// tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>::type is
/// std::tuple<Xs..., Ys...>.
template <typename X, typename Y>
struct tuple_concat
{
    // Let std::tuple_cat compute the concatenated tuple type; decltype
    // recovers it without ever constructing a value.
    using type = decltype(std::tuple_cat(std::declval<X>(), std::declval<Y>()));
};
} // namespace
// Typed-test fixtures: each binds a fixed (A-layout, B-layout) pair and lets
// the Tuple parameter supply the (A, B, Compute, C) data types.
// Row/Row = MK_KN, Row/Col = MK_NK, Col/Row = KM_KN, Col/Col = KM_NK.
template <typename Tuple>
class TestGemmUniversal_BF16_MK_KN
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_BF16_MK_NK
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_BF16_KM_KN
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_BF16_KM_NK
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
{
};
// clang-format off
// Data-type combinations exercised per layout; columns follow the header row.
using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< BF16, BF16, BF16, BF16>
>;
using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< BF16, BF16, BF16, BF16>
>;
using KernelTypes_KM_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< BF16, BF16, BF16, BF16>
>;
using KernelTypes_KM_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< BF16, BF16, BF16, BF16>
>;
// clang-format on
// Register the fixtures with their type lists, then pull in the shared
// BF16 test-case definitions.
TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN);
TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK);
TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN);
TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK);
#include "test_gemm_universal_ut_cases_bf16.inc"
// Entry point: initializes GoogleTest, optionally overrides the file-scope
// filtering globals (param_mask, instance_index) from the command line, then
// runs every registered test.
int main(int argc, char** argv)
{
    testing::InitGoogleTest(&argc, argv);
    // No extra arguments keeps the defaults; exactly two extra arguments set
    // the filters; anything else prints usage (tests still run either way).
    switch(argc)
    {
    case 1: break;
    case 3:
        param_mask     = strtol(argv[1], nullptr, 0);
        instance_index = atoi(argv[2]);
        break;
    default:
        std::cout << "Usage of " << argv[0] << std::endl;
        std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
        break;
    }
    return RUN_ALL_TESTS();
}

View File

@@ -1,111 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_universal_util.hpp"
// Test-filtering globals read by the generated test cases (see the .inc file).
// param_mask: presumably a bitmask selecting which problem parameters to run
// (0xffff enables all) — confirm against test_gemm_universal_util.hpp.
ck::index_t param_mask = 0xffff;
// Index of the single device instance to exercise; -1 means run all instances
// (see the usage message in main()).
ck::index_t instance_index = -1;
// Shorthand aliases for the element types and tensor layouts used below.
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
namespace {
/// Type-level concatenation of two std::tuple specializations:
/// tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>::type is
/// std::tuple<Xs..., Ys...>.
template <typename X, typename Y>
struct tuple_concat
{
    // Let std::tuple_cat compute the concatenated tuple type; decltype
    // recovers it without ever constructing a value.
    using type = decltype(std::tuple_cat(std::declval<X>(), std::declval<Y>()));
};
} // namespace
// Typed-test fixtures: each binds a fixed (A-layout, B-layout) pair and lets
// the Tuple parameter supply the (A, B, Compute, C) data types.
// Row/Row = MK_KN, Row/Col = MK_NK, Col/Row = KM_KN, Col/Col = KM_NK.
template <typename Tuple>
class TestGemmUniversal_FP16_MK_KN
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_FP16_MK_NK
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_FP16_KM_KN
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmUniversal_FP16_KM_NK
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
{
};
// clang-format off
// Data-type combinations exercised per layout; columns follow the header row.
// Mixed F16/F8 combinations are compiled in only when FP8 is enabled for the
// current architecture (or explicitly forced on unsupported arches).
using KernelTypes_MK_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>,
#endif
std::tuple< F16, F16, F16, F16>
>;
using KernelTypes_MK_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
std::tuple< F16, F8, F16, F16>,
std::tuple< F8, F16, F16, F16>,
#endif
std::tuple< F16, F16, F16, F16>
>;
using KernelTypes_KM_NK = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>
>;
using KernelTypes_KM_KN = ::testing::Types<
// ADataType, BDataType, ComputeDataType, CDataType
std::tuple< F16, F16, F16, F16>
>;
// clang-format on
// Register the fixtures with their type lists, then pull in the shared
// FP16 test-case definitions.
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN);
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK);
TYPED_TEST_SUITE(TestGemmUniversal_FP16_KM_NK, KernelTypes_KM_NK);
TYPED_TEST_SUITE(TestGemmUniversal_FP16_KM_KN, KernelTypes_KM_KN);
#include "test_gemm_universal_ut_cases_fp16.inc"
// Entry point: initializes GoogleTest, optionally overrides the file-scope
// filtering globals (param_mask, instance_index) from the command line, then
// runs every registered test.
int main(int argc, char** argv)
{
    testing::InitGoogleTest(&argc, argv);
    // No extra arguments keeps the defaults; exactly two extra arguments set
    // the filters; anything else prints usage (tests still run either way).
    switch(argc)
    {
    case 1: break;
    case 3:
        param_mask     = strtol(argv[1], nullptr, 0);
        instance_index = atoi(argv[2]);
        break;
    default:
        std::cout << "Usage of " << argv[0] << std::endl;
        std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
        break;
    }
    return RUN_ALL_TESTS();
}