mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[rocm-libraries] ROCm/rocm-libraries#4415 (commit b3b4af7)
[CK] Remove duplicated XDL/WMMA tests
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
## Motivation
When we started the RDNA4 support, the XDL instances did not support
WMMA instructions, so we duplicated some tests.
In this PR, we consolidated most of the duplicated test files into
common test files.
## Technical Details
The following tests were unified:
- `batched_gemm`
- `batched_gemm_gemm`
- `gemm_add`
- `gemm_universal`
- `grouped_convnd_bwd_data`
The following tests were duplicated exactly, copied into two files
with `_xdl` and `_wmma` suffixes. Each pair is now unified in a single
file without a suffix:
- `gemm_multi_abd`
- `gemm_b_scale`
There is still an apparent duplication which is a special case, namely
`test_grouped_convnd_bwd_weight_interface_{suffix}` where `{suffix}` is
`xdl` or `wmma`.
However, the WMMA code relies on an old implementation and is expected
to be removed in the future. In addition, it differs significantly from
the XDL implementation.
Therefore, it was decided to keep both files separate instead of
attempting any unification.
## Test Plan
`CMakeLists.txt` files were modified to support the new, unified tests.
In particular, testing was done for `gfx90a`, `gfx1201` and `gfx11`
architectures.
## Test Result
All tests passed successfully on all three tested architectures.
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
6bf2423685
commit
9a32f0ea19
@@ -1,12 +1,12 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists.
# This is necessary as these tests are universal and don't have "xdl" or "wmma" in their name to signify their
# target arch. But they will fail to link the instance library if there are no instances present for the current arch.
if (CK_USE_XDL OR CK_USE_WMMA)
    add_gtest_executable(test_gemm_multi_abd test_gemm_multi_abd.cpp)
    # NOTE(review): `result` is presumably set by add_gtest_executable to signal
    # whether the target was created — confirm in the parent CMake helpers.
    if(result EQUAL 0)
        target_link_libraries(test_gemm_multi_abd PRIVATE utility device_gemm_multi_abd_instance)
    endif()
endif()
|
||||
@@ -1,154 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

// Typed-gtest registration for the multi-A/B/D GEMM instances: each entry of
// KernelTypesABD below is one kernel configuration driven through the shared
// TestGemmCommon fixture (see test_gemm_common.hpp).

#include <tuple>

#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_multi_abd_impl.hpp"
#include "test_gemm_common.hpp"

namespace ck {
namespace test {

// Tensor-layout shorthands used by the configurations below.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

// Element data-type shorthands.
using I8   = int8_t;
using BF16 = ck::bhalf_t;

// Element-wise operators combined into the tested epilogues.
using PassThrough         = ck::tensor_operation::element_wise::PassThrough;
using Multiply            = ck::tensor_operation::element_wise::Multiply;
using Add                 = ck::tensor_operation::element_wise::Add;
using MultiplyAdd         = ck::tensor_operation::element_wise::MultiplyAdd;
using FastGelu            = ck::tensor_operation::element_wise::FastGelu;
using AddFastGelu         = ck::tensor_operation::element_wise::AddFastGelu;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using MultiplyFastGelu    = ck::tensor_operation::element_wise::MultiplyFastGelu;

// One std::tuple per tested configuration. The ten elements appear to be, in
// order: As layouts, Bs layouts, Ds layouts, As data types, Bs data types,
// Ds data types, E data type, A element-op, B element-op, CDE element-op.
// NOTE(review): the role names above are inferred from
// profile_gemm_multi_abd_impl.hpp — confirm against that header.
using KernelTypesABD = ::testing::Types<
    // Two B tensors, one D tensor; Multiply on B, Add epilogue.
    std::tuple<ck::Tuple<Row>, ck::Tuple<Row, Row>, ck::Tuple<Row>,
               ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>,
               BF16, PassThrough, Multiply, Add>,
    // Same, with column-major B tensors.
    std::tuple<ck::Tuple<Row>, ck::Tuple<Col, Col>, ck::Tuple<Row>,
               ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>,
               BF16, PassThrough, Multiply, Add>,
    // AddFastGelu epilogue, row-major B tensors.
    std::tuple<ck::Tuple<Row>, ck::Tuple<Row, Row>, ck::Tuple<Row>,
               ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>,
               BF16, PassThrough, Multiply, AddFastGelu>,
    // AddFastGelu epilogue, column-major B tensors.
    std::tuple<ck::Tuple<Row>, ck::Tuple<Col, Col>, ck::Tuple<Row>,
               ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>,
               BF16, PassThrough, Multiply, AddFastGelu>,
    // No D tensors; FastGelu epilogue, row-major B tensors.
    std::tuple<ck::Tuple<Row>, ck::Tuple<Row, Row>, ck::Tuple<>,
               ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>,
               BF16, PassThrough, Multiply, FastGelu>,
    // No D tensors; FastGelu epilogue, column-major B tensors.
    std::tuple<ck::Tuple<Row>, ck::Tuple<Col, Col>, ck::Tuple<>,
               ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>,
               BF16, PassThrough, Multiply, FastGelu>,
    // No D tensors; PassThrough epilogue, row-major B tensors.
    std::tuple<ck::Tuple<Row>, ck::Tuple<Row, Row>, ck::Tuple<>,
               ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>,
               BF16, PassThrough, Multiply, PassThrough>,
    // No D tensors; PassThrough epilogue, column-major B tensors.
    std::tuple<ck::Tuple<Row>, ck::Tuple<Col, Col>, ck::Tuple<>,
               ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>,
               BF16, PassThrough, Multiply, PassThrough>,
    // Single B tensor, two D tensors; MultiplyAddFastGelu epilogue.
    std::tuple<ck::Tuple<Row>, ck::Tuple<Row>, ck::Tuple<Row, Row>,
               ck::Tuple<BF16>, ck::Tuple<I8>, ck::Tuple<BF16, BF16>,
               BF16, PassThrough, PassThrough, MultiplyAddFastGelu>,
    // Single B tensor, two D tensors; MultiplyAdd epilogue.
    std::tuple<ck::Tuple<Row>, ck::Tuple<Row>, ck::Tuple<Row, Row>,
               ck::Tuple<BF16>, ck::Tuple<I8>, ck::Tuple<BF16, BF16>,
               BF16, PassThrough, PassThrough, MultiplyAdd>,
    // Single B and D tensors; MultiplyFastGelu epilogue.
    std::tuple<ck::Tuple<Row>, ck::Tuple<Row>, ck::Tuple<Row>,
               ck::Tuple<BF16>, ck::Tuple<I8>, ck::Tuple<BF16>,
               BF16, PassThrough, PassThrough, MultiplyFastGelu>,
    // Single B and D tensors; Multiply epilogue.
    std::tuple<ck::Tuple<Row>, ck::Tuple<Row>, ck::Tuple<Row>,
               ck::Tuple<BF16>, ck::Tuple<I8>, ck::Tuple<BF16>,
               BF16, PassThrough, PassThrough, Multiply>>;

// Instantiate the shared fixture once per configuration in KernelTypesABD.
TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD);
TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); }

} // namespace test
} // namespace ck
|
||||
Reference in New Issue
Block a user