mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK] Remove duplicated XDL/WMMA tests (#4415)
## Motivation
When we started the RDNA4 support, the XDL instances were not supporting
WMMA instructions, so we duplicated some tests.
In this issue, we simplified most of the duplicated test files into
common test files.
## Technical Details
The following tests were unified:
- `batched_gemm`
- `batched_gemm_gemm`
- `gemm_add`
- `gemm_universal`
- `grouped_convnd_bwd_data`
The following tests were duplicated exactly, and copied into two files
with `_xdl` and `_wmma` suffixes. Now they are unified in one single
file without suffix:
- `gemm_multi_abd`
- `gemm_b_scale`
There is still an apparent duplication which is a special case, namely
`test_grouped_convnd_bwd_weight_interface_{suffix}` where `{suffix}` is
`xdl` or `wmma`.
However, the WMMA code relies on an old implementation, and is expected
to be removed in the future. In addition, it differs from the XDL
implementation significantly.
Therefore, it was decided to keep both files separate instead of
attempting any unification.
## Test Plan
`CMakeLists.txt` files were modified to support the new, unified tests.
In particular, testing was done for `gfx90a`, `gfx1201` and `gfx11`
architectures.
## Test Result
All tests passed successfully on all three tested architectures.
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
---------
Co-authored-by: Fernando Jiménez <fernando.jimenez@streamhpc.com>
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_gtest_executable(test_batched_gemm_xdl test_batched_gemm_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_batched_gemm_xdl PRIVATE utility device_batched_gemm_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_batched_gemm_wmma test_batched_gemm_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_batched_gemm_wmma PRIVATE utility device_batched_gemm_instance)
|
||||
# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
|
||||
# as these tests are universal and dont have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link
|
||||
# the instance library if there's no instances present for the current arch.
|
||||
if (CK_USE_XDL OR CK_USE_WMMA)
|
||||
add_gtest_executable(test_batched_gemm test_batched_gemm.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -214,6 +214,13 @@ TEST_F(TestBatchedGemm, bf16)
|
||||
this->params.push_back({68, 68, 68, 2});
|
||||
this->params.push_back({40, 40, 40, 2});
|
||||
this->params.push_back({256, 256, 128, 3});
|
||||
|
||||
// Tests with larger MNK
|
||||
this->params.push_back({512, 256, 128, 1});
|
||||
this->params.push_back({256, 240, 192, 2});
|
||||
this->params.push_back({256, 256, 128, 3});
|
||||
this->params.push_back({240, 128, 128, 5});
|
||||
|
||||
this->template Run<ck::bhalf_t>();
|
||||
}
|
||||
#endif
|
||||
@@ -226,7 +233,13 @@ TEST_F(TestBatchedGemm, fp16)
|
||||
this->params.push_back({60, 60, 60, 2});
|
||||
this->params.push_back({68, 68, 68, 2});
|
||||
this->params.push_back({40, 40, 40, 2});
|
||||
|
||||
// Tests with larger MNK
|
||||
this->params.push_back({512, 256, 128, 1});
|
||||
this->params.push_back({256, 240, 192, 2});
|
||||
this->params.push_back({256, 256, 128, 3});
|
||||
this->params.push_back({240, 128, 128, 5});
|
||||
|
||||
this->template Run<ck::half_t>();
|
||||
}
|
||||
#endif
|
||||
@@ -1,268 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "profiler/profile_batched_gemm_impl.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
|
||||
static ck::index_t param_mask = 0xffff;
|
||||
static ck::index_t instance_index = -1;
|
||||
struct GemmParams
|
||||
{
|
||||
ck::index_t M;
|
||||
ck::index_t N;
|
||||
ck::index_t K;
|
||||
ck::index_t BatchCount;
|
||||
};
|
||||
|
||||
class TestBatchedGemm : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
std::vector<GemmParams> params;
|
||||
|
||||
template <typename DataType>
|
||||
void Run()
|
||||
{
|
||||
using namespace ck::tensor_operation::device;
|
||||
|
||||
bool pass = true;
|
||||
for(size_t i = 0; i < params.size(); i++)
|
||||
{
|
||||
if((param_mask & (1 << i)) == 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
auto& param = params[i];
|
||||
const auto M = param.M;
|
||||
const auto N = param.N;
|
||||
const auto K = param.K;
|
||||
const auto BatchCount = param.BatchCount;
|
||||
|
||||
pass = pass && ck::profiler::profile_batched_gemm_impl<DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
Row,
|
||||
Row,
|
||||
Row,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
DeviceBatchedGemm<Row,
|
||||
Row,
|
||||
Row,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>(
|
||||
true,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
K,
|
||||
N,
|
||||
N,
|
||||
M * K,
|
||||
K * N,
|
||||
M * N,
|
||||
BatchCount,
|
||||
instance_index);
|
||||
|
||||
pass = pass && ck::profiler::profile_batched_gemm_impl<DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
Row,
|
||||
Col,
|
||||
Row,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
DeviceBatchedGemm<Row,
|
||||
Col,
|
||||
Row,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>(
|
||||
true,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
K,
|
||||
K,
|
||||
N,
|
||||
M * K,
|
||||
K * N,
|
||||
M * N,
|
||||
BatchCount,
|
||||
instance_index);
|
||||
|
||||
pass = pass && ck::profiler::profile_batched_gemm_impl<DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
Col,
|
||||
Row,
|
||||
Row,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
DeviceBatchedGemm<Col,
|
||||
Row,
|
||||
Row,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>(
|
||||
true,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
M,
|
||||
N,
|
||||
N,
|
||||
M * K,
|
||||
K * N,
|
||||
M * N,
|
||||
BatchCount,
|
||||
instance_index);
|
||||
|
||||
pass = pass && ck::profiler::profile_batched_gemm_impl<DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
Col,
|
||||
Col,
|
||||
Row,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
DeviceBatchedGemm<Col,
|
||||
Col,
|
||||
Row,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>(
|
||||
true,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
M,
|
||||
K,
|
||||
N,
|
||||
M * K,
|
||||
K * N,
|
||||
M * N,
|
||||
BatchCount,
|
||||
instance_index);
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
// #ifdef CK_ENABLE_INT8
|
||||
// TEST_F(TestBatchedGemm, i8)
|
||||
// {
|
||||
// this->params.push_back({64, 64, 64, 2});
|
||||
// this->params.push_back({64, 64, 64, 1});
|
||||
// this->params.push_back({60, 60, 60, 2});
|
||||
// this->params.push_back({68, 68, 68, 2});
|
||||
// this->params.push_back({40, 40, 40, 2});
|
||||
// this->params.push_back({256, 256, 128, 3});
|
||||
// this->template Run<int8_t>();
|
||||
// }
|
||||
// #endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
TEST_F(TestBatchedGemm, bf16)
|
||||
{
|
||||
this->params.push_back({64, 64, 64, 2});
|
||||
this->params.push_back({64, 64, 64, 1});
|
||||
this->params.push_back({40, 40, 40, 2});
|
||||
this->params.push_back({256, 256, 128, 3});
|
||||
|
||||
// Tests with larger MNK
|
||||
this->params.push_back({512, 256, 128, 1});
|
||||
this->params.push_back({256, 240, 192, 2});
|
||||
this->params.push_back({256, 256, 128, 3});
|
||||
this->params.push_back({240, 128, 128, 5});
|
||||
this->template Run<ck::bhalf_t>();
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
TEST_F(TestBatchedGemm, fp16)
|
||||
{
|
||||
this->params.push_back({64, 64, 64, 2});
|
||||
this->params.push_back({64, 64, 64, 1});
|
||||
this->params.push_back({40, 40, 40, 2});
|
||||
this->params.push_back({256, 256, 128, 3});
|
||||
|
||||
// Tests with larger MNK
|
||||
this->params.push_back({512, 256, 128, 1});
|
||||
this->params.push_back({256, 240, 192, 2});
|
||||
this->params.push_back({256, 256, 128, 3});
|
||||
this->params.push_back({240, 128, 128, 5});
|
||||
this->template Run<ck::half_t>();
|
||||
}
|
||||
#endif
|
||||
|
||||
// #ifdef CK_ENABLE_FP32
|
||||
// TEST_F(TestBatchedGemm, fp32)
|
||||
// {
|
||||
// this->params.push_back({64, 64, 64, 2});
|
||||
// this->params.push_back({64, 64, 64, 1});
|
||||
// this->params.push_back({60, 60, 60, 2});
|
||||
// this->params.push_back({68, 68, 68, 2});
|
||||
// this->params.push_back({40, 40, 40, 2});
|
||||
// this->params.push_back({256, 256, 128, 3});
|
||||
// this->template Run<float>();
|
||||
// }
|
||||
// #endif
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
if(argc == 1) {}
|
||||
else if(argc == 3)
|
||||
{
|
||||
param_mask = strtol(argv[1], nullptr, 0);
|
||||
instance_index = atoi(argv[2]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Usage of " << argv[0] << std::endl;
|
||||
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
|
||||
}
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -1,17 +1,17 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_gtest_executable(test_batched_gemm_gemm_fp16_xdl test_batched_gemm_gemm_fp16_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_batched_gemm_gemm_fp16_xdl PRIVATE utility device_batched_gemm_gemm_instance)
|
||||
# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
|
||||
# as these tests are universal and dont have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link
|
||||
# the instance library if there's no instances present for the current arch.
|
||||
if (CK_USE_XDL OR CK_USE_WMMA)
|
||||
add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_batched_gemm_gemm_bf16_wmma test_batched_gemm_gemm_bf16_wmma_cshuffle_v3.cpp)
|
||||
add_gtest_executable(test_batched_gemm_gemm_bf16_wmma_cshuffle_v3 test_batched_gemm_gemm_bf16_wmma_cshuffle_v3.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_batched_gemm_gemm_bf16_wmma PRIVATE utility device_batched_gemm_gemm_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_batched_gemm_gemm_fp16_wmma test_batched_gemm_gemm_fp16_wmma_cshuffle_v3.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_batched_gemm_gemm_fp16_wmma PRIVATE utility device_batched_gemm_gemm_instance)
|
||||
target_link_libraries(test_batched_gemm_gemm_bf16_wmma_cshuffle_v3 PRIVATE utility device_batched_gemm_gemm_instance)
|
||||
endif()
|
||||
|
||||
@@ -136,13 +136,20 @@ using KernelTypes = ::testing::Types<
|
||||
|
||||
TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16) { this->Run(); }
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16)
|
||||
{
|
||||
this->bench_ = false;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadM)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{136, 128, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = false;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
@@ -151,6 +158,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadN)
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 136, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = false;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
@@ -160,6 +169,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadK)
|
||||
{128, 128, 40, 128, 1},
|
||||
{128, 128, 136, 128, 1},
|
||||
};
|
||||
this->bench_ = false;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
@@ -168,6 +179,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadO)
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 32, 136, 1},
|
||||
};
|
||||
this->bench_ = false;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
@@ -176,6 +189,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddM)
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{129, 128, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = false;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
@@ -184,6 +199,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN)
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 129, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = false;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
@@ -193,6 +210,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK)
|
||||
{128, 128, 33, 128, 1},
|
||||
{128, 128, 129, 128, 1},
|
||||
};
|
||||
this->bench_ = false;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
@@ -202,6 +221,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddO)
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 32, 129, 1},
|
||||
};
|
||||
this->bench_ = false;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_batched_gemm_gemm_util.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
std::tuple<F16, F16, F16, F16, Row, Col, Row, Row>,
|
||||
std::tuple<F16, F16, F16, F16, Row, Col, Col, Row>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16)
|
||||
{
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadM)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{136, 128, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadN)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 136, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 40, 128, 1},
|
||||
{128, 128, 136, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadO)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 32, 136, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddM)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{129, 128, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 129, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 33, 128, 1},
|
||||
{128, 128, 129, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
// If kernel B1Layout is RowMajor, expect not to support odd O size
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddO)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 32, 129, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{256, 256, 64, 64, 768},
|
||||
{256, 256, 128, 128, 768},
|
||||
{512, 512, 64, 64, 768},
|
||||
{512, 512, 128, 128, 768},
|
||||
{1024, 1024, 64, 64, 768},
|
||||
{1024, 1024, 128, 128, 768},
|
||||
{2048, 2048, 64, 64, 768},
|
||||
{2048, 2048, 128, 128, 768},
|
||||
{4096, 4096, 64, 64, 768},
|
||||
{4096, 4096, 128, 128, 768},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = false;
|
||||
this->Run();
|
||||
}
|
||||
@@ -3,29 +3,29 @@
|
||||
|
||||
# Implements test instances for MultipleD with xdl and wmma support.
|
||||
|
||||
add_gtest_executable(test_gemm_add_xdl test_gemm_add_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_xdl PRIVATE utility device_gemm_add_instance)
|
||||
endif()
|
||||
# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
|
||||
# as these tests are universal and dont have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link
|
||||
# the instance library if there's no instances present for the current arch.
|
||||
if (CK_USE_XDL OR CK_USE_WMMA)
|
||||
add_gtest_executable(test_gemm_add test_gemm_add.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_relu_xdl test_gemm_add_relu_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_relu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance)
|
||||
endif()
|
||||
add_gtest_executable(test_gemm_add_relu test_gemm_add_relu.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_silu_xdl test_gemm_add_silu_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_silu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance)
|
||||
endif()
|
||||
add_gtest_executable(test_gemm_add_silu test_gemm_add_silu.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_silu_wmma test_gemm_add_silu_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_silu_wmma PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_fastgelu_xdl test_gemm_add_fastgelu_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_fastgelu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance)
|
||||
add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_fastgelu_instance)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_fastgelu_wmma test_gemm_fastgelu_wmma.cpp)
|
||||
@@ -33,16 +33,6 @@ if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_fastgelu_wmma PRIVATE utility device_gemm_fastgelu_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_fastgelu_wmma test_gemm_add_fastgelu_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_fastgelu_wmma PRIVATE utility device_gemm_add_fastgelu_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_wmma test_gemm_add_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_wmma PRIVATE utility device_gemm_add_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_add_fastgelu_wmma test_gemm_add_add_fastgelu_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_add_fastgelu_wmma PRIVATE utility device_gemm_add_add_fastgelu_instance)
|
||||
@@ -66,9 +56,4 @@ endif()
|
||||
add_gtest_executable(test_gemm_bilinear_wmma test_gemm_bilinear_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_bilinear_wmma PRIVATE utility device_gemm_bilinear_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_relu_wmma test_gemm_add_relu_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_relu_wmma PRIVATE utility device_gemm_add_relu_instance)
|
||||
endif()
|
||||
@@ -26,7 +26,9 @@ class TestGemmAdd : public TestGemmD0Common<Tuple>
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, BF16, F32, BF16, BF16, Row, Row, Row, Row>>;
|
||||
std::tuple<BF16, BF16, F32, BF16, BF16, Row, Row, Row, Row>,
|
||||
std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmAdd, KernelTypes);
|
||||
TYPED_TEST(TestGemmAdd, Test_BF16FP16) { this->Run(); }
|
||||
TYPED_TEST(TestGemmAdd, Test_BF16FP16_FP16FP16_INT8) { this->Run(); }
|
||||
@@ -26,10 +26,12 @@ class TestGemmAddFastgelu : public TestGemmD0Common<Tuple>
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, Row, Row, Row, Row>,
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, Row, Col, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, Col, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, Col, Col, Row, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmAddFastgelu, KernelTypes);
|
||||
TYPED_TEST(TestGemmAddFastgelu, Test_FP16FP16) { this->Run(); }
|
||||
TYPED_TEST(TestGemmAddFastgelu, Test_BF16FP16_FP16FP16_INT8) { this->Run(); }
|
||||
@@ -1,33 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_fastgelu_impl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddFastgelu : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
using ProfileCall = typename TestGemmD0Common<Tuple>::ProfileCall;
|
||||
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_add_fastgelu_impl<
|
||||
typename TestGemmD0Common<Tuple>::ADataType,
|
||||
typename TestGemmD0Common<Tuple>::BDataType,
|
||||
typename TestGemmD0Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0Common<Tuple>::EDataType,
|
||||
typename TestGemmD0Common<Tuple>::ALayout,
|
||||
typename TestGemmD0Common<Tuple>::BLayout,
|
||||
typename TestGemmD0Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmAddFastgelu, KernelTypes);
|
||||
TYPED_TEST(TestGemmAddFastgelu, Test_BF16FP16) { this->Run(); }
|
||||
@@ -27,7 +27,9 @@ class TestGemmAddRelu : public TestGemmD0Common<Tuple>
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;
|
||||
std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, BF16, F32, BF16, BF16, Row, Row, Row, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmAddRelu, KernelTypes);
|
||||
TYPED_TEST(TestGemmAddRelu, Test_BF16FP16_INT8) { this->Run(); }
|
||||
TYPED_TEST(TestGemmAddRelu, Test_BF16FP16_FP16FP16_INT8) { this->Run(); }
|
||||
@@ -1,33 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_relu_impl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddRelu : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
using ProfileCall = typename TestGemmD0Common<Tuple>::ProfileCall;
|
||||
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_add_relu_impl<
|
||||
typename TestGemmD0Common<Tuple>::ADataType,
|
||||
typename TestGemmD0Common<Tuple>::BDataType,
|
||||
typename TestGemmD0Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0Common<Tuple>::EDataType,
|
||||
typename TestGemmD0Common<Tuple>::ALayout,
|
||||
typename TestGemmD0Common<Tuple>::BLayout,
|
||||
typename TestGemmD0Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, BF16, F32, BF16, BF16, Row, Row, Row, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmAddRelu, KernelTypes);
|
||||
TYPED_TEST(TestGemmAddRelu, Test_BF16FP16) { this->Run(); }
|
||||
@@ -27,7 +27,9 @@ class TestGemmAddSilu : public TestGemmD0Common<Tuple>
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;
|
||||
std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>,
|
||||
std::tuple<F16, F16, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, BF16, F32, BF16, BF16, Row, Row, Row, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmAddSilu, KernelTypes);
|
||||
TYPED_TEST(TestGemmAddSilu, Test_BF16FP16_INT8) { this->Run(); }
|
||||
TYPED_TEST(TestGemmAddSilu, Test_BF16FP16_BF16FP16_INT8) { this->Run(); }
|
||||
@@ -1,33 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_silu_impl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddSilu : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
using ProfileCall = typename TestGemmD0Common<Tuple>::ProfileCall;
|
||||
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_add_silu_impl<
|
||||
typename TestGemmD0Common<Tuple>::ADataType,
|
||||
typename TestGemmD0Common<Tuple>::BDataType,
|
||||
typename TestGemmD0Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0Common<Tuple>::EDataType,
|
||||
typename TestGemmD0Common<Tuple>::ALayout,
|
||||
typename TestGemmD0Common<Tuple>::BLayout,
|
||||
typename TestGemmD0Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, BF16, F32, BF16, BF16, Row, Row, Row, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmAddSilu, KernelTypes);
|
||||
TYPED_TEST(TestGemmAddSilu, Test_BF16FP16_BF16FP16) { this->Run(); }
|
||||
@@ -1,32 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_impl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAdd : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
using ProfileCall = typename TestGemmD0Common<Tuple>::ProfileCall;
|
||||
|
||||
ProfileCall GetImpl() override
|
||||
{
|
||||
return ck::profiler::profile_gemm_add_impl<typename TestGemmD0Common<Tuple>::ADataType,
|
||||
typename TestGemmD0Common<Tuple>::BDataType,
|
||||
typename TestGemmD0Common<Tuple>::AccDataType,
|
||||
typename TestGemmD0Common<Tuple>::D0DataType,
|
||||
typename TestGemmD0Common<Tuple>::EDataType,
|
||||
typename TestGemmD0Common<Tuple>::ALayout,
|
||||
typename TestGemmD0Common<Tuple>::BLayout,
|
||||
typename TestGemmD0Common<Tuple>::D0Layout,
|
||||
typename TestGemmD0Common<Tuple>::ELayout>;
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
|
||||
std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmAdd, KernelTypes);
|
||||
TYPED_TEST(TestGemmAdd, Test_BF16FP16_INT8) { this->Run(); }
|
||||
@@ -1,12 +1,12 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_gtest_executable(test_gemm_b_scale_xdl test_gemm_b_scale_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_b_scale_xdl PRIVATE utility device_gemm_b_scale_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_b_scale_wmma test_gemm_b_scale_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_b_scale_wmma PRIVATE utility device_gemm_b_scale_instance)
|
||||
# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
|
||||
# as these tests are universal and dont have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link
|
||||
# the instance library if there's no instances present for the current arch.
|
||||
if (CK_USE_XDL OR CK_USE_WMMA)
|
||||
add_gtest_executable(test_gemm_b_scale test_gemm_b_scale.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_b_scale PRIVATE utility device_gemm_b_scale_instance)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_b_scale_util.hpp"
|
||||
|
||||
using I4 = ck::pk_i4_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmBScale_MK_NK
|
||||
: public ck::test::TestGemmBScale<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, BScaleDataType, ComputeDataType, CDataType
|
||||
std::tuple< F16, I4, F16, F16, F16>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmBScale_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
#include "test_gemm_b_scale_ut_cases.inc"
|
||||
@@ -1,12 +1,12 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_gtest_executable(test_gemm_multi_abd_wmma test_gemm_multi_abd_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_multi_abd_wmma PRIVATE utility device_gemm_multi_abd_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_multi_abd_xdl test_gemm_multi_abd_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_multi_abd_xdl PRIVATE utility device_gemm_multi_abd_instance)
|
||||
endif()
|
||||
# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
|
||||
# as these tests are universal and dont have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link
|
||||
# the instance library if there's no instances present for the current arch.
|
||||
if (CK_USE_XDL OR CK_USE_WMMA)
|
||||
add_gtest_executable(test_gemm_multi_abd test_gemm_multi_abd.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_multi_abd PRIVATE utility device_gemm_multi_abd_instance)
|
||||
endif()
|
||||
endif()
|
||||
@@ -1,154 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_multi_abd_impl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using I8 = int8_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
|
||||
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
|
||||
|
||||
using KernelTypesABD = ::testing::Types<std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
Add>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
Add>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
AddFastGelu>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
AddFastGelu>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
FastGelu>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
FastGelu>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
PassThrough>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
PassThrough>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAddFastGelu>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyFastGelu>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD);
|
||||
TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); }
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
@@ -1,32 +1,23 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_gtest_executable(test_gemm_universal_wmma_fp16 test_gemm_universal_wmma_fp16.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_wmma_fp16 PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_universal_wmma_bf16 test_gemm_universal_wmma_bf16.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_wmma_bf16 PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
|
||||
# as these tests are universal and dont have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link
|
||||
# the instance library if there's no instances present for the current arch.
|
||||
if (CK_USE_XDL OR CK_USE_WMMA)
|
||||
add_gtest_executable(test_gemm_universal_fp16 test_gemm_universal_fp16.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_fp16 PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_universal_wmma_fp8 test_gemm_universal_wmma_fp8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_wmma_fp8 PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
add_gtest_executable(test_gemm_universal_fp8 test_gemm_universal_fp8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_fp8 PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_universal_xdl_fp16 test_gemm_universal_xdl_fp16.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_xdl_fp16 PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_universal_xdl_fp8 test_gemm_universal_xdl_fp8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_xdl_fp8 PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_universal_xdl_bf16 test_gemm_universal_xdl_bf16.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_xdl_bf16 PRIVATE utility device_gemm_universal_instance)
|
||||
add_gtest_executable(test_gemm_universal_bf16 test_gemm_universal_bf16.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_bf16 PRIVATE utility device_gemm_universal_instance)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -55,7 +55,8 @@ class TestGemmUniversal_BF16_KM_NK
|
||||
// clang-format off
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
@@ -66,11 +67,6 @@ using KernelTypes_MK_NK = ::testing::Types<
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
|
||||
using KernelTypes_KM_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
|
||||
using KernelTypes_KM_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
#if defined(CK_ENABLE_FP8)
|
||||
@@ -78,6 +74,12 @@ using KernelTypes_KM_NK = ::testing::Types<
|
||||
#endif
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
|
||||
using KernelTypes_KM_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN);
|
||||
@@ -44,31 +44,34 @@ class TestGemmUniversal_FP8_MK_NK
|
||||
// clang-format off
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) && !defined(CK_USE_WMMA_FP8)
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
#endif
|
||||
std::tuple< F8, F16, F16, F16>>;
|
||||
#elif defined(CK_USE_WMMA_FP8)
|
||||
// Fallback test type when WMMA FP8 is used
|
||||
std::tuple< F8, F8, F8, BF16>>;
|
||||
#else
|
||||
// Fallback test type when FP8 is not enabled
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
std::tuple< F16, F16, F16, F16>>;
|
||||
#endif
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) && !defined(CK_USE_WMMA_FP8)
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
#endif
|
||||
std::tuple< F8, F16, F16, F16>>;
|
||||
#elif defined(CK_USE_WMMA_FP8)
|
||||
// Fallback test type when WMMA FP8 is used
|
||||
std::tuple< F8, F8, F8, BF16>>;
|
||||
#else
|
||||
// Fallback test type when FP8 is not enabled
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
|
||||
std::tuple< F16, F16, F16, F16>>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
|
||||
#include "test_gemm_universal_ut_cases_fp8.inc"
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
@@ -1,78 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_universal_util.hpp"
|
||||
ck::index_t param_mask = 0xffff;
|
||||
ck::index_t instance_index = -1;
|
||||
#if defined(CK_USE_WMMA_FP8)
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP8_MK_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP8_MK_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
std::tuple< F8, F8, F8, BF16>
|
||||
>;
|
||||
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
std::tuple< F8, F8, F8, BF16>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
#include "test_gemm_universal_ut_cases_fp8.inc"
|
||||
|
||||
#endif // CK_USE_WMMA_FP8
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
if(argc == 1) {}
|
||||
else if(argc == 3)
|
||||
{
|
||||
param_mask = strtol(argv[1], nullptr, 0);
|
||||
instance_index = atoi(argv[2]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Usage of " << argv[0] << std::endl;
|
||||
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
|
||||
}
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_universal_util.hpp"
|
||||
ck::index_t param_mask = 0xffff;
|
||||
ck::index_t instance_index = -1;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_BF16_MK_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_BF16_MK_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_BF16_KM_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_BF16_KM_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
|
||||
using KernelTypes_KM_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
|
||||
using KernelTypes_KM_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK);
|
||||
|
||||
#include "test_gemm_universal_ut_cases_bf16.inc"
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
if(argc == 1) {}
|
||||
else if(argc == 3)
|
||||
{
|
||||
param_mask = strtol(argv[1], nullptr, 0);
|
||||
instance_index = atoi(argv[2]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Usage of " << argv[0] << std::endl;
|
||||
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
|
||||
}
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -1,111 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_universal_util.hpp"
|
||||
ck::index_t param_mask = 0xffff;
|
||||
ck::index_t instance_index = -1;
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_MK_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_MK_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_KM_KN
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversal_FP16_KM_NK
|
||||
: public ck::test::TestGemmUniversal<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
|
||||
#endif
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
|
||||
#endif
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
|
||||
using KernelTypes_KM_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
|
||||
using KernelTypes_KM_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
std::tuple< F16, F16, F16, F16>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP16_KM_NK, KernelTypes_KM_NK);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_FP16_KM_KN, KernelTypes_KM_KN);
|
||||
|
||||
#include "test_gemm_universal_ut_cases_fp16.inc"
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
if(argc == 1) {}
|
||||
else if(argc == 3)
|
||||
{
|
||||
param_mask = strtol(argv[1], nullptr, 0);
|
||||
instance_index = atoi(argv[2]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Usage of " << argv[0] << std::endl;
|
||||
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
|
||||
}
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -21,11 +21,10 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
target_compile_options(test_grouped_conv_bwd_data_scale PRIVATE -Wno-global-constructors -Wno-undef)
|
||||
target_link_libraries(test_grouped_conv_bwd_data_scale PRIVATE gtest_main getopt::getopt utility device_grouped_conv3d_bwd_data_scale_instance)
|
||||
endif()
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_interface_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance)
|
||||
endif()
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data_interface_wmma test_grouped_convnd_bwd_data_interface_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_interface_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance)
|
||||
|
||||
if (CK_USE_XDL OR CK_USE_WMMA)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance)
|
||||
endif()
|
||||
endif()
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
|
||||
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
@@ -32,7 +33,7 @@ static constexpr auto ConvBwdDataDefault = ConvBackwardDataSpecialization::Def
|
||||
static constexpr auto Filter1x1Stride1Pad0 = ConvBackwardDataSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <typename Tuple, ConvBackwardDataSpecialization ConvSpec>
|
||||
class TestGroupedConvndBwdData : public ::testing::Test
|
||||
class TestGroupedConvndBwdDataXdl : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
static constexpr ck::index_t NDimSpatial = 2;
|
||||
@@ -119,6 +120,100 @@ class TestGroupedConvndBwdData : public ::testing::Test
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tuple, ConvBackwardDataSpecialization ConvSpec>
|
||||
class TestGroupedConvndBwdDataWmma : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
static constexpr ck::index_t NDimSpatial = 2;
|
||||
|
||||
using OutLayout = std::tuple_element_t<0, Tuple>;
|
||||
using WeiLayout = std::tuple_element_t<1, Tuple>;
|
||||
using InLayout = std::tuple_element_t<2, Tuple>;
|
||||
|
||||
// clang-format off
|
||||
using GroupedConvBwdDataDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
//| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< NDimSpatial,OutLayout, WeiLayout, ck::Tuple<>, InLayout, DataType, DataType, AccDataType, DataType, ck::Tuple<>, DataType, Pass, Pass, Pass, ConvSpec, 64, 32, 64, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>;
|
||||
// clang-format on
|
||||
|
||||
ck::utils::conv::ConvParam conv_param;
|
||||
|
||||
void SetUp() override
|
||||
{
|
||||
if(!ck::is_gfx11_supported())
|
||||
{
|
||||
GTEST_SKIP();
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
bool Run()
|
||||
{
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
std::array<ck::index_t, NDimSpatial + 3> out_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> out_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> wei_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> wei_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> in_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> in_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads{};
|
||||
|
||||
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
|
||||
|
||||
copy(out_g_n_k_wos_desc.GetLengths(), out_lengths);
|
||||
copy(out_g_n_k_wos_desc.GetStrides(), out_strides);
|
||||
copy(wei_g_k_c_xs_desc.GetLengths(), wei_lengths);
|
||||
copy(wei_g_k_c_xs_desc.GetStrides(), wei_strides);
|
||||
copy(in_g_n_c_wis_desc.GetLengths(), in_lengths);
|
||||
copy(in_g_n_c_wis_desc.GetStrides(), in_strides);
|
||||
copy(conv_param.conv_filter_strides_, conv_filter_strides);
|
||||
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
|
||||
copy(conv_param.input_left_pads_, input_left_pads);
|
||||
copy(conv_param.input_right_pads_, input_right_pads);
|
||||
|
||||
auto conv = GroupedConvBwdDataDeviceInstance{};
|
||||
|
||||
auto argument = conv.MakeArgument(nullptr,
|
||||
nullptr,
|
||||
std::array<const void*, 0>{},
|
||||
nullptr,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
Pass{},
|
||||
Pass{},
|
||||
Pass{});
|
||||
return conv.IsSupportedArgument(argument);
|
||||
}
|
||||
};
|
||||
|
||||
using GNHWC = ck::tensor_layout::convolution::GNHWC;
|
||||
using NHWGC = ck::tensor_layout::convolution::NHWGC;
|
||||
|
||||
@@ -131,20 +226,35 @@ using KernelTypes =
|
||||
::testing::Types<std::tuple<GNHWK, GKYXC, GNHWC>, std::tuple<NHWGK, GKYXC, NHWGC>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdDataDefault : public TestGroupedConvndBwdData<Tuple, ConvBwdDataDefault>
|
||||
class TestGroupedConvndBwdDataDefaultXdl
|
||||
: public TestGroupedConvndBwdDataXdl<Tuple, ConvBwdDataDefault>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdDataFilter1x1
|
||||
: public TestGroupedConvndBwdData<Tuple, Filter1x1Stride1Pad0>
|
||||
class TestGroupedConvndBwdDataFilter1x1Xdl
|
||||
: public TestGroupedConvndBwdDataXdl<Tuple, Filter1x1Stride1Pad0>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataDefault, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataFilter1x1, KernelTypes);
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdDataDefaultWmma
|
||||
: public TestGroupedConvndBwdDataWmma<Tuple, ConvBwdDataDefault>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdDataFilter1x1, SpecializationCheck)
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdDataFilter1x1Wmma
|
||||
: public TestGroupedConvndBwdDataWmma<Tuple, Filter1x1Stride1Pad0>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataDefaultXdl, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataFilter1x1Xdl, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataDefaultWmma, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataFilter1x1Wmma, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdDataFilter1x1Xdl, SpecializationCheckXdl)
|
||||
{
|
||||
// Check filter 3,3 instead of 1,1
|
||||
this->conv_param = {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
@@ -167,7 +277,30 @@ TYPED_TEST(TestGroupedConvndBwdDataFilter1x1, SpecializationCheck)
|
||||
EXPECT_TRUE(is_supported);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdDataDefault, VectorLoadCheck)
|
||||
TYPED_TEST(TestGroupedConvndBwdDataFilter1x1Wmma, SpecializationCheckWmma)
|
||||
{
|
||||
// Check filter 3,3 instead of 1,1
|
||||
this->conv_param = {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
bool is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
|
||||
// Check strides 2,2 instead of 1,1
|
||||
this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
|
||||
// Check with pad
|
||||
this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
|
||||
// Supported version
|
||||
this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_TRUE(is_supported);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdDataDefaultXdl, VectorLoadCheckXdl)
|
||||
{
|
||||
// vector load for A
|
||||
this->conv_param = {2, 2, 128, 129, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
@@ -179,7 +312,19 @@ TYPED_TEST(TestGroupedConvndBwdDataDefault, VectorLoadCheck)
|
||||
EXPECT_FALSE(is_supported);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdDataDefault, SplitK)
|
||||
TYPED_TEST(TestGroupedConvndBwdDataDefaultWmma, VectorLoadCheckWmma)
|
||||
{
|
||||
// vector load for A
|
||||
this->conv_param = {2, 2, 128, 129, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
bool is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
// vector load for B, E, Ds
|
||||
this->conv_param = {2, 2, 128, 128, 257, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdDataDefaultXdl, SplitK)
|
||||
{
|
||||
if(ck::is_xdl_supported())
|
||||
{
|
||||
@@ -1,186 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
|
||||
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using DataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
using ConvBackwardDataSpecialization =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
|
||||
static constexpr auto ConvBwdDataDefault = ConvBackwardDataSpecialization::Default;
|
||||
static constexpr auto Filter1x1Stride1Pad0 = ConvBackwardDataSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <typename Tuple, ConvBackwardDataSpecialization ConvSpec>
|
||||
class TestGroupedConvndBwdData : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
static constexpr ck::index_t NDimSpatial = 2;
|
||||
|
||||
using OutLayout = std::tuple_element_t<0, Tuple>;
|
||||
using WeiLayout = std::tuple_element_t<1, Tuple>;
|
||||
using InLayout = std::tuple_element_t<2, Tuple>;
|
||||
|
||||
// clang-format off
|
||||
using GroupedConvBwdDataDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
//| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< NDimSpatial,OutLayout, WeiLayout, ck::Tuple<>, InLayout, DataType, DataType, AccDataType, DataType, ck::Tuple<>, DataType, Pass, Pass, Pass, ConvSpec, 64, 32, 64, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>;
|
||||
// clang-format on
|
||||
|
||||
ck::utils::conv::ConvParam conv_param;
|
||||
|
||||
void SetUp() override
|
||||
{
|
||||
if(!ck::is_gfx11_supported())
|
||||
{
|
||||
GTEST_SKIP();
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
bool Run()
|
||||
{
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
std::array<ck::index_t, NDimSpatial + 3> out_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> out_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> wei_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> wei_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> in_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> in_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads{};
|
||||
|
||||
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
|
||||
|
||||
copy(out_g_n_k_wos_desc.GetLengths(), out_lengths);
|
||||
copy(out_g_n_k_wos_desc.GetStrides(), out_strides);
|
||||
copy(wei_g_k_c_xs_desc.GetLengths(), wei_lengths);
|
||||
copy(wei_g_k_c_xs_desc.GetStrides(), wei_strides);
|
||||
copy(in_g_n_c_wis_desc.GetLengths(), in_lengths);
|
||||
copy(in_g_n_c_wis_desc.GetStrides(), in_strides);
|
||||
copy(conv_param.conv_filter_strides_, conv_filter_strides);
|
||||
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
|
||||
copy(conv_param.input_left_pads_, input_left_pads);
|
||||
copy(conv_param.input_right_pads_, input_right_pads);
|
||||
|
||||
auto conv = GroupedConvBwdDataDeviceInstance{};
|
||||
|
||||
auto argument = conv.MakeArgument(nullptr,
|
||||
nullptr,
|
||||
std::array<const void*, 0>{},
|
||||
nullptr,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
Pass{},
|
||||
Pass{},
|
||||
Pass{});
|
||||
return conv.IsSupportedArgument(argument);
|
||||
}
|
||||
};
|
||||
|
||||
using GNHWC = ck::tensor_layout::convolution::GNHWC;
|
||||
using NHWGC = ck::tensor_layout::convolution::NHWGC;
|
||||
|
||||
using GKYXC = ck::tensor_layout::convolution::GKYXC;
|
||||
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
using NHWGK = ck::tensor_layout::convolution::NHWGK;
|
||||
|
||||
using KernelTypes =
|
||||
::testing::Types<std::tuple<GNHWK, GKYXC, GNHWC>, std::tuple<NHWGK, GKYXC, NHWGC>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdDataDefault : public TestGroupedConvndBwdData<Tuple, ConvBwdDataDefault>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdDataFilter1x1
|
||||
: public TestGroupedConvndBwdData<Tuple, Filter1x1Stride1Pad0>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataDefault, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataFilter1x1, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdDataFilter1x1, SpecializationCheck)
|
||||
{
|
||||
// Check filter 3,3 instead of 1,1
|
||||
this->conv_param = {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
bool is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
|
||||
// Check strides 2,2 instead of 1,1
|
||||
this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
|
||||
// Check with pad
|
||||
this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
|
||||
// Supported version
|
||||
this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_TRUE(is_supported);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdDataDefault, VectorLoadCheck)
|
||||
{
|
||||
// vector load for A
|
||||
this->conv_param = {2, 2, 128, 129, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
bool is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
// vector load for B, E, Ds
|
||||
this->conv_param = {2, 2, 128, 128, 257, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
}
|
||||
Reference in New Issue
Block a user