mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 00:39:02 +00:00
[CK] Remove duplicated XDL/WMMA tests (#4415)

## Motivation

When we started the RDNA4 support, the XDL instances did not support WMMA instructions, so we duplicated some tests. This change merges most of the duplicated test files into common test files.

## Technical Details

The following tests were unified:

- `batched_gemm`
- `batched_gemm_gemm`
- `gemm_add`
- `gemm_universal`
- `grouped_convnd_bwd_data`

The following tests were exact duplicates, copied into two files with `_xdl` and `_wmma` suffixes. They are now unified in a single file without a suffix:

- `gemm_multi_abd`
- `gemm_b_scale`

There is still an apparent duplication, which is a special case: `test_grouped_convnd_bwd_weight_interface_{suffix}`, where `{suffix}` is `xdl` or `wmma`. However, the WMMA code relies on an old implementation and is expected to be removed in the future. In addition, it differs significantly from the XDL implementation. Therefore, it was decided to keep both files separate instead of attempting any unification.

## Test Plan

The `CMakeLists.txt` files were modified to support the new, unified tests. In particular, testing was done for the `gfx90a`, `gfx1201` and `gfx11` architectures.

## Test Result

All tests passed successfully on all three tested architectures.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Fernando Jiménez <fernando.jimenez@streamhpc.com>
155 lines
8.9 KiB
C++
155 lines
8.9 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#include <tuple>
|
|
|
|
#include "gtest/gtest.h"
|
|
#include "ck/ck.hpp"
|
|
#include "profiler/profile_gemm_multi_abd_impl.hpp"
|
|
#include "test_gemm_common.hpp"
|
|
|
|
namespace ck {
|
|
namespace test {
|
|
|
|
using Row = ck::tensor_layout::gemm::RowMajor;
|
|
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
|
|
|
using I8 = int8_t;
|
|
using BF16 = ck::bhalf_t;
|
|
|
|
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
|
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
|
using Add = ck::tensor_operation::element_wise::Add;
|
|
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
|
|
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
|
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
|
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
|
|
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
|
|
|
|
using KernelTypesABD = ::testing::Types<std::tuple<ck::Tuple<Row>,
|
|
ck::Tuple<Row, Row>,
|
|
ck::Tuple<Row>,
|
|
ck::Tuple<BF16>,
|
|
ck::Tuple<I8, BF16>,
|
|
ck::Tuple<BF16>,
|
|
BF16,
|
|
PassThrough,
|
|
Multiply,
|
|
Add>,
|
|
std::tuple<ck::Tuple<Row>,
|
|
ck::Tuple<Col, Col>,
|
|
ck::Tuple<Row>,
|
|
ck::Tuple<BF16>,
|
|
ck::Tuple<I8, BF16>,
|
|
ck::Tuple<BF16>,
|
|
BF16,
|
|
PassThrough,
|
|
Multiply,
|
|
Add>,
|
|
std::tuple<ck::Tuple<Row>,
|
|
ck::Tuple<Row, Row>,
|
|
ck::Tuple<Row>,
|
|
ck::Tuple<BF16>,
|
|
ck::Tuple<I8, BF16>,
|
|
ck::Tuple<BF16>,
|
|
BF16,
|
|
PassThrough,
|
|
Multiply,
|
|
AddFastGelu>,
|
|
std::tuple<ck::Tuple<Row>,
|
|
ck::Tuple<Col, Col>,
|
|
ck::Tuple<Row>,
|
|
ck::Tuple<BF16>,
|
|
ck::Tuple<I8, BF16>,
|
|
ck::Tuple<BF16>,
|
|
BF16,
|
|
PassThrough,
|
|
Multiply,
|
|
AddFastGelu>,
|
|
std::tuple<ck::Tuple<Row>,
|
|
ck::Tuple<Row, Row>,
|
|
ck::Tuple<>,
|
|
ck::Tuple<BF16>,
|
|
ck::Tuple<I8, BF16>,
|
|
ck::Tuple<>,
|
|
BF16,
|
|
PassThrough,
|
|
Multiply,
|
|
FastGelu>,
|
|
std::tuple<ck::Tuple<Row>,
|
|
ck::Tuple<Col, Col>,
|
|
ck::Tuple<>,
|
|
ck::Tuple<BF16>,
|
|
ck::Tuple<I8, BF16>,
|
|
ck::Tuple<>,
|
|
BF16,
|
|
PassThrough,
|
|
Multiply,
|
|
FastGelu>,
|
|
std::tuple<ck::Tuple<Row>,
|
|
ck::Tuple<Row, Row>,
|
|
ck::Tuple<>,
|
|
ck::Tuple<BF16>,
|
|
ck::Tuple<I8, BF16>,
|
|
ck::Tuple<>,
|
|
BF16,
|
|
PassThrough,
|
|
Multiply,
|
|
PassThrough>,
|
|
std::tuple<ck::Tuple<Row>,
|
|
ck::Tuple<Col, Col>,
|
|
ck::Tuple<>,
|
|
ck::Tuple<BF16>,
|
|
ck::Tuple<I8, BF16>,
|
|
ck::Tuple<>,
|
|
BF16,
|
|
PassThrough,
|
|
Multiply,
|
|
PassThrough>,
|
|
std::tuple<ck::Tuple<Row>,
|
|
ck::Tuple<Row>,
|
|
ck::Tuple<Row, Row>,
|
|
ck::Tuple<BF16>,
|
|
ck::Tuple<I8>,
|
|
ck::Tuple<BF16, BF16>,
|
|
BF16,
|
|
PassThrough,
|
|
PassThrough,
|
|
MultiplyAddFastGelu>,
|
|
std::tuple<ck::Tuple<Row>,
|
|
ck::Tuple<Row>,
|
|
ck::Tuple<Row, Row>,
|
|
ck::Tuple<BF16>,
|
|
ck::Tuple<I8>,
|
|
ck::Tuple<BF16, BF16>,
|
|
BF16,
|
|
PassThrough,
|
|
PassThrough,
|
|
MultiplyAdd>,
|
|
std::tuple<ck::Tuple<Row>,
|
|
ck::Tuple<Row>,
|
|
ck::Tuple<Row>,
|
|
ck::Tuple<BF16>,
|
|
ck::Tuple<I8>,
|
|
ck::Tuple<BF16>,
|
|
BF16,
|
|
PassThrough,
|
|
PassThrough,
|
|
MultiplyFastGelu>,
|
|
std::tuple<ck::Tuple<Row>,
|
|
ck::Tuple<Row>,
|
|
ck::Tuple<Row>,
|
|
ck::Tuple<BF16>,
|
|
ck::Tuple<I8>,
|
|
ck::Tuple<BF16>,
|
|
BF16,
|
|
PassThrough,
|
|
PassThrough,
|
|
Multiply>>;
|
|
|
|
TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD);
|
|
TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); }
|
|
|
|
} // namespace test
|
|
} // namespace ck
|