diff --git a/test/batched_gemm/CMakeLists.txt b/test/batched_gemm/CMakeLists.txt index 926fafcc97..02f42d4427 100644 --- a/test/batched_gemm/CMakeLists.txt +++ b/test/batched_gemm/CMakeLists.txt @@ -1,12 +1,12 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -add_gtest_executable(test_batched_gemm_xdl test_batched_gemm_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_batched_gemm_xdl PRIVATE utility device_batched_gemm_instance) -endif() - -add_gtest_executable(test_batched_gemm_wmma test_batched_gemm_wmma.cpp) -if(result EQUAL 0) - target_link_libraries(test_batched_gemm_wmma PRIVATE utility device_batched_gemm_instance) +# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary +# as these tests are universal and don't have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link +# the instance library if there are no instances present for the current arch. +if (CK_USE_XDL OR CK_USE_WMMA) + add_gtest_executable(test_batched_gemm test_batched_gemm.cpp) + if(result EQUAL 0) + target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance) + endif() endif() diff --git a/test/batched_gemm/test_batched_gemm_xdl.cpp b/test/batched_gemm/test_batched_gemm.cpp similarity index 96% rename from test/batched_gemm/test_batched_gemm_xdl.cpp rename to test/batched_gemm/test_batched_gemm.cpp index 88170f9909..82068b4170 100644 --- a/test/batched_gemm/test_batched_gemm_xdl.cpp +++ b/test/batched_gemm/test_batched_gemm.cpp @@ -214,6 +214,13 @@ TEST_F(TestBatchedGemm, bf16) this->params.push_back({68, 68, 68, 2}); this->params.push_back({40, 40, 40, 2}); this->params.push_back({256, 256, 128, 3}); + + // Tests with larger MNK + this->params.push_back({512, 256, 128, 1}); + this->params.push_back({256, 240, 192, 2}); + this->params.push_back({256, 256, 128, 3}); + this->params.push_back({240, 128, 128, 5}); + this->template Run(); } #endif @@ -226,7 +233,13 @@ TEST_F(TestBatchedGemm, fp16) this->params.push_back({60, 60, 60, 2}); this->params.push_back({68, 68, 68, 2}); this->params.push_back({40, 40, 40, 2}); + + // Tests with larger MNK + this->params.push_back({512, 256, 128, 1}); + this->params.push_back({256, 240, 192, 2}); this->params.push_back({256, 256, 128, 3}); + this->params.push_back({240, 128, 128, 5}); + this->template Run(); } #endif diff --git a/test/batched_gemm/test_batched_gemm_wmma.cpp b/test/batched_gemm/test_batched_gemm_wmma.cpp deleted file mode 100644 index db751cf7d1..0000000000 --- a/test/batched_gemm/test_batched_gemm_wmma.cpp +++ /dev/null @@ -1,268 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
-// SPDX-License-Identifier: MIT - -#include -#include -#include -#include -#include - -#include - -#include "profiler/profile_batched_gemm_impl.hpp" - -#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" -static ck::index_t param_mask = 0xffff; -static ck::index_t instance_index = -1; -struct GemmParams -{ - ck::index_t M; - ck::index_t N; - ck::index_t K; - ck::index_t BatchCount; -}; - -class TestBatchedGemm : public ::testing::Test -{ - protected: - using Row = ck::tensor_layout::gemm::RowMajor; - using Col = ck::tensor_layout::gemm::ColumnMajor; - - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - - std::vector params; - - template - void Run() - { - using namespace ck::tensor_operation::device; - - bool pass = true; - for(size_t i = 0; i < params.size(); i++) - { - if((param_mask & (1 << i)) == 0) - { - continue; - } - auto& param = params[i]; - const auto M = param.M; - const auto N = param.N; - const auto K = param.K; - const auto BatchCount = param.BatchCount; - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, - 1, - false, - 1, - M, - N, - K, - K, - N, - N, - M * K, - K * N, - M * N, - BatchCount, - instance_index); - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, - 1, - false, - 1, - M, - N, - K, - K, - K, - N, - M * K, - K * N, - M * N, - BatchCount, - instance_index); - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, - 1, - false, - 1, - M, - N, - K, - M, - N, - N, - M * K, - K * N, - M * N, - BatchCount, - instance_index); - - pass = pass && ck::profiler::profile_batched_gemm_impl>( - true, - 1, - false, - 1, - M, - N, - K, - M, - K, - N, - M * K, - K * N, - M * N, - BatchCount, - instance_index); - } - EXPECT_TRUE(pass); - } -}; - -// #ifdef CK_ENABLE_INT8 -// TEST_F(TestBatchedGemm, i8) -// { -// this->params.push_back({64, 64, 64, 2}); -// this->params.push_back({64, 64, 64, 1}); -// this->params.push_back({60, 60, 60, 2}); -// this->params.push_back({68, 68, 68, 2}); -// this->params.push_back({40, 40, 40, 2}); -// this->params.push_back({256, 256, 128, 3}); -// this->template Run(); -// } -// #endif - -#ifdef CK_ENABLE_BF16 -TEST_F(TestBatchedGemm, bf16) -{ - this->params.push_back({64, 64, 64, 2}); - this->params.push_back({64, 64, 64, 1}); - this->params.push_back({40, 40, 40, 2}); - this->params.push_back({256, 256, 128, 3}); - - // Tests with larger MNK - this->params.push_back({512, 256, 128, 1}); - this->params.push_back({256, 240, 192, 2}); - this->params.push_back({256, 256, 128, 3}); - this->params.push_back({240, 128, 128, 5}); - this->template Run(); -} -#endif - -#ifdef CK_ENABLE_FP16 -TEST_F(TestBatchedGemm, fp16) -{ - this->params.push_back({64, 64, 64, 2}); - this->params.push_back({64, 64, 64, 1}); - this->params.push_back({40, 40, 40, 2}); - this->params.push_back({256, 256, 128, 3}); - - // Tests with larger MNK - this->params.push_back({512, 256, 128, 1}); - this->params.push_back({256, 240, 192, 2}); - this->params.push_back({256, 256, 128, 3}); - this->params.push_back({240, 128, 128, 5}); - this->template Run(); -} -#endif - -// #ifdef CK_ENABLE_FP32 -// TEST_F(TestBatchedGemm, fp32) -// { -// this->params.push_back({64, 64, 64, 2}); -// this->params.push_back({64, 64, 64, 1}); -// this->params.push_back({60, 60, 60, 2}); -// this->params.push_back({68, 68, 68, 2}); -// this->params.push_back({40, 40, 40, 2}); -// this->params.push_back({256, 256, 128, 3}); -// this->template Run(); -// } -// #endif - -int main(int argc, char** argv) -{ - 
testing::InitGoogleTest(&argc, argv); - if(argc == 1) {} - else if(argc == 3) - { - param_mask = strtol(argv[1], nullptr, 0); - instance_index = atoi(argv[2]); - } - else - { - std::cout << "Usage of " << argv[0] << std::endl; - std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; - } - return RUN_ALL_TESTS(); -} diff --git a/test/batched_gemm_gemm/CMakeLists.txt b/test/batched_gemm_gemm/CMakeLists.txt index a12d5c3435..a66b011b33 100644 --- a/test/batched_gemm_gemm/CMakeLists.txt +++ b/test/batched_gemm_gemm/CMakeLists.txt @@ -1,17 +1,17 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -add_gtest_executable(test_batched_gemm_gemm_fp16_xdl test_batched_gemm_gemm_fp16_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_batched_gemm_gemm_fp16_xdl PRIVATE utility device_batched_gemm_gemm_instance) +# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary +# as these tests are universal and don't have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link +# the instance library if there are no instances present for the current arch. +if (CK_USE_XDL OR CK_USE_WMMA) + add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp) + if(result EQUAL 0) + target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) + endif() endif() -add_gtest_executable(test_batched_gemm_gemm_bf16_wmma test_batched_gemm_gemm_bf16_wmma_cshuffle_v3.cpp) +add_gtest_executable(test_batched_gemm_gemm_bf16_wmma_cshuffle_v3 test_batched_gemm_gemm_bf16_wmma_cshuffle_v3.cpp) if(result EQUAL 0) - target_link_libraries(test_batched_gemm_gemm_bf16_wmma PRIVATE utility device_batched_gemm_gemm_instance) -endif() - -add_gtest_executable(test_batched_gemm_gemm_fp16_wmma test_batched_gemm_gemm_fp16_wmma_cshuffle_v3.cpp) -if(result EQUAL 0) - target_link_libraries(test_batched_gemm_gemm_fp16_wmma PRIVATE utility device_batched_gemm_gemm_instance) + target_link_libraries(test_batched_gemm_gemm_bf16_wmma_cshuffle_v3 PRIVATE utility device_batched_gemm_gemm_instance) endif() diff --git a/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp b/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp similarity index 94% rename from test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp rename to test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp index 011e53a99a..8d6405e618 100644 --- a/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp +++ b/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp @@ -136,13 +136,20 @@ using KernelTypes = ::testing::Types< TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes); -TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16) { this->Run(); } +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16) +{ + this->bench_ = false; + this->verify_ = true; + this->Run(); +} TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadM) { this->lengths_ = std::vector>{ {136, 128, 32, 128, 1}, }; + this->bench_ = false; + this->verify_ = true; this->Run(); } @@ -151,6 +158,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadN) this->lengths_ = std::vector>{ {128, 136, 32, 128, 1}, }; + this->bench_ = false; + this->verify_ = true; this->Run(); } @@ -160,6 +169,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadK) { this->lengths_ = std::vector>{ {128, 128, 40, 128, 1}, {128, 128, 136, 128, 1}, }; + this->bench_ = false; + this->verify_ = true; this->Run(); } @@ -168,6 +179,8 @@
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadO) this->lengths_ = std::vector>{ {128, 128, 32, 136, 1}, }; + this->bench_ = false; + this->verify_ = true; this->Run(); } @@ -176,6 +189,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddM) this->lengths_ = std::vector>{ {129, 128, 32, 128, 1}, }; + this->bench_ = false; + this->verify_ = true; this->Run(); } @@ -184,6 +199,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN) this->lengths_ = std::vector>{ {128, 129, 32, 128, 1}, }; + this->bench_ = false; + this->verify_ = true; this->Run(); } @@ -193,6 +210,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK) {128, 128, 33, 128, 1}, {128, 128, 129, 128, 1}, }; + this->bench_ = false; + this->verify_ = true; this->Run(); } @@ -202,6 +221,8 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddO) this->lengths_ = std::vector>{ {128, 128, 32, 129, 1}, }; + this->bench_ = false; + this->verify_ = true; this->Run(); } diff --git a/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_wmma_cshuffle_v3.cpp b/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_wmma_cshuffle_v3.cpp deleted file mode 100644 index da97a95f4e..0000000000 --- a/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_wmma_cshuffle_v3.cpp +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "gtest/gtest.h" -#include "test_batched_gemm_gemm_util.hpp" - -template -class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm -{ -}; - -// clang-format off -using KernelTypes = ::testing::Types< - std::tuple, - std::tuple - >; -// clang-format on - -TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes); - -TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16) -{ - this->bench_ = true; - this->verify_ = true; - this->Run(); -} - -TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadM) -{ - this->lengths_ = std::vector>{ - {136, 128, 32, 128, 1}, - }; - this->bench_ = true; - this->verify_ = true; - this->Run(); -} - -TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadN) -{ - this->lengths_ = std::vector>{ - {128, 136, 32, 128, 1}, - }; - this->bench_ = true; - this->verify_ = true; - this->Run(); -} - -TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadK) -{ - this->lengths_ = std::vector>{ - {128, 128, 40, 128, 1}, - {128, 128, 136, 128, 1}, - }; - this->bench_ = true; - this->verify_ = true; - this->Run(); -} - -TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadO) -{ - this->lengths_ = std::vector>{ - {128, 128, 32, 136, 1}, - }; - this->bench_ = true; - this->verify_ = true; - this->Run(); -} - -TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddM) -{ - this->lengths_ = std::vector>{ - {129, 128, 32, 128, 1}, - }; - this->bench_ = true; - this->verify_ = true; - this->Run(); -} - -TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN) -{ - this->lengths_ = std::vector>{ - {128, 129, 32, 128, 1}, - }; - this->bench_ = true; - this->verify_ = true; - this->Run(); -} - -TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK) -{ - this->lengths_ = std::vector>{ - {128, 128, 33, 128, 1}, - {128, 128, 129, 128, 1}, - }; - this->bench_ = true; - this->verify_ = true; - this->Run(); -} - -// If kernel B1Layout is RowMajor, expect not to support odd O size -TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddO) -{ - this->lengths_ = std::vector>{ - {128, 128, 32, 129, 1}, - }; - this->bench_ = true; - this->verify_ = true; - this->Run(); -} - -TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16) -{ - this->lengths_ = std::vector>{ - {256, 256, 64, 64, 768}, - {256, 
256, 128, 128, 768}, - {512, 512, 64, 64, 768}, - {512, 512, 128, 128, 768}, - {1024, 1024, 64, 64, 768}, - {1024, 1024, 128, 128, 768}, - {2048, 2048, 64, 64, 768}, - {2048, 2048, 128, 128, 768}, - {4096, 4096, 64, 64, 768}, - {4096, 4096, 128, 128, 768}, - }; - this->bench_ = true; - this->verify_ = false; - this->Run(); -} diff --git a/test/gemm_add/CMakeLists.txt b/test/gemm_add/CMakeLists.txt index 17bfadf95d..4cfd6abdc1 100644 --- a/test/gemm_add/CMakeLists.txt +++ b/test/gemm_add/CMakeLists.txt @@ -3,29 +3,29 @@ # Implements test instances for MultipleD with xdl and wmma support. -add_gtest_executable(test_gemm_add_xdl test_gemm_add_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_add_xdl PRIVATE utility device_gemm_add_instance) -endif() +# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary +# as these tests are universal and don't have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link +# the instance library if there are no instances present for the current arch. +if (CK_USE_XDL OR CK_USE_WMMA) + add_gtest_executable(test_gemm_add test_gemm_add.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance) + endif() -add_gtest_executable(test_gemm_add_relu_xdl test_gemm_add_relu_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_add_relu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) -endif() + add_gtest_executable(test_gemm_add_relu test_gemm_add_relu.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) + endif() -add_gtest_executable(test_gemm_add_silu_xdl test_gemm_add_silu_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_add_silu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) -endif() + add_gtest_executable(test_gemm_add_silu test_gemm_add_silu.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) + endif() -add_gtest_executable(test_gemm_add_silu_wmma test_gemm_add_silu_wmma.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_add_silu_wmma PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) -endif() - -add_gtest_executable(test_gemm_add_fastgelu_xdl test_gemm_add_fastgelu_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_add_fastgelu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) + add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_fastgelu_instance) + endif() endif() add_gtest_executable(test_gemm_fastgelu_wmma test_gemm_fastgelu_wmma.cpp) @@ -33,16 +33,6 @@ if(result EQUAL 0) target_link_libraries(test_gemm_fastgelu_wmma PRIVATE utility device_gemm_fastgelu_instance) endif() -add_gtest_executable(test_gemm_add_fastgelu_wmma test_gemm_add_fastgelu_wmma.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_add_fastgelu_wmma PRIVATE utility device_gemm_add_fastgelu_instance) -endif() - -add_gtest_executable(test_gemm_add_wmma test_gemm_add_wmma.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_add_wmma PRIVATE utility device_gemm_add_instance) -endif() - add_gtest_executable(test_gemm_add_add_fastgelu_wmma test_gemm_add_add_fastgelu_wmma.cpp)
if(result EQUAL 0) target_link_libraries(test_gemm_add_add_fastgelu_wmma PRIVATE utility device_gemm_add_add_fastgelu_instance) @@ -66,9 +56,4 @@ endif() add_gtest_executable(test_gemm_bilinear_wmma test_gemm_bilinear_wmma.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_bilinear_wmma PRIVATE utility device_gemm_bilinear_instance) -endif() - -add_gtest_executable(test_gemm_add_relu_wmma test_gemm_add_relu_wmma.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_add_relu_wmma PRIVATE utility device_gemm_add_relu_instance) endif() \ No newline at end of file diff --git a/test/gemm_add/test_gemm_add_wmma.cpp b/test/gemm_add/test_gemm_add.cpp similarity index 83% rename from test/gemm_add/test_gemm_add_wmma.cpp rename to test/gemm_add/test_gemm_add.cpp index bc440a9ae8..61f4372c76 100644 --- a/test/gemm_add/test_gemm_add_wmma.cpp +++ b/test/gemm_add/test_gemm_add.cpp @@ -26,7 +26,9 @@ class TestGemmAdd : public TestGemmD0Common }; using KernelTypes = ::testing::Types, - std::tuple>; + std::tuple, + std::tuple, + std::tuple>; TYPED_TEST_SUITE(TestGemmAdd, KernelTypes); -TYPED_TEST(TestGemmAdd, Test_BF16FP16) { this->Run(); } +TYPED_TEST(TestGemmAdd, Test_BF16FP16_FP16FP16_INT8) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_fastgelu_wmma.cpp b/test/gemm_add/test_gemm_add_fastgelu.cpp similarity index 80% rename from test/gemm_add/test_gemm_add_fastgelu_wmma.cpp rename to test/gemm_add/test_gemm_add_fastgelu.cpp index e72b1c3761..2e3fe24e3c 100644 --- a/test/gemm_add/test_gemm_add_fastgelu_wmma.cpp +++ b/test/gemm_add/test_gemm_add_fastgelu.cpp @@ -26,10 +26,12 @@ class TestGemmAddFastgelu : public TestGemmD0Common } }; -using KernelTypes = ::testing::Types, +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, std::tuple, std::tuple, std::tuple>; TYPED_TEST_SUITE(TestGemmAddFastgelu, KernelTypes); -TYPED_TEST(TestGemmAddFastgelu, Test_FP16FP16) { this->Run(); } +TYPED_TEST(TestGemmAddFastgelu, Test_BF16FP16_FP16FP16_INT8) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp deleted file mode 100644 index 21c5b47f88..0000000000 --- a/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "gtest/gtest.h" -#include "ck/ck.hpp" -#include "profiler/profile_gemm_add_fastgelu_impl.hpp" -#include "test_gemm_common.hpp" - -template -class TestGemmAddFastgelu : public TestGemmD0Common -{ - using ProfileCall = typename TestGemmD0Common::ProfileCall; - - ProfileCall GetImpl() override - { - return ck::profiler::profile_gemm_add_fastgelu_impl< - typename TestGemmD0Common::ADataType, - typename TestGemmD0Common::BDataType, - typename TestGemmD0Common::AccDataType, - typename TestGemmD0Common::D0DataType, - typename TestGemmD0Common::EDataType, - typename TestGemmD0Common::ALayout, - typename TestGemmD0Common::BLayout, - typename TestGemmD0Common::D0Layout, - typename TestGemmD0Common::ELayout>; - } -}; - -using KernelTypes = ::testing::Types, - std::tuple>; - -TYPED_TEST_SUITE(TestGemmAddFastgelu, KernelTypes); -TYPED_TEST(TestGemmAddFastgelu, Test_BF16FP16) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_relu_xdl.cpp b/test/gemm_add/test_gemm_add_relu.cpp similarity index 80% rename from test/gemm_add/test_gemm_add_relu_xdl.cpp rename to test/gemm_add/test_gemm_add_relu.cpp index d87ac74188..649017ddcb 100644 --- a/test/gemm_add/test_gemm_add_relu_xdl.cpp +++ b/test/gemm_add/test_gemm_add_relu.cpp @@ -27,7 +27,9 @@ class TestGemmAddRelu : public TestGemmD0Common }; using KernelTypes = ::testing::Types, - std::tuple>; + std::tuple, + std::tuple, + std::tuple>; TYPED_TEST_SUITE(TestGemmAddRelu, KernelTypes); -TYPED_TEST(TestGemmAddRelu, Test_BF16FP16_INT8) { this->Run(); } +TYPED_TEST(TestGemmAddRelu, Test_BF16FP16_FP16FP16_INT8) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_relu_wmma.cpp b/test/gemm_add/test_gemm_add_relu_wmma.cpp deleted file mode 100644 index 1d099f8bcf..0000000000 --- a/test/gemm_add/test_gemm_add_relu_wmma.cpp +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "gtest/gtest.h" -#include "ck/ck.hpp" -#include "profiler/profile_gemm_add_relu_impl.hpp" -#include "test_gemm_common.hpp" - -template -class TestGemmAddRelu : public TestGemmD0Common -{ - using ProfileCall = typename TestGemmD0Common::ProfileCall; - - ProfileCall GetImpl() override - { - return ck::profiler::profile_gemm_add_relu_impl< - typename TestGemmD0Common::ADataType, - typename TestGemmD0Common::BDataType, - typename TestGemmD0Common::AccDataType, - typename TestGemmD0Common::D0DataType, - typename TestGemmD0Common::EDataType, - typename TestGemmD0Common::ALayout, - typename TestGemmD0Common::BLayout, - typename TestGemmD0Common::D0Layout, - typename TestGemmD0Common::ELayout>; - } -}; - -using KernelTypes = ::testing::Types, - std::tuple>; - -TYPED_TEST_SUITE(TestGemmAddRelu, KernelTypes); -TYPED_TEST(TestGemmAddRelu, Test_BF16FP16) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_silu_xdl.cpp b/test/gemm_add/test_gemm_add_silu.cpp similarity index 80% rename from test/gemm_add/test_gemm_add_silu_xdl.cpp rename to test/gemm_add/test_gemm_add_silu.cpp index 3af279c286..64b51b8b1b 100644 --- a/test/gemm_add/test_gemm_add_silu_xdl.cpp +++ b/test/gemm_add/test_gemm_add_silu.cpp @@ -27,7 +27,9 @@ class TestGemmAddSilu : public TestGemmD0Common }; using KernelTypes = ::testing::Types, - std::tuple>; + std::tuple, + std::tuple, + std::tuple>; TYPED_TEST_SUITE(TestGemmAddSilu, KernelTypes); -TYPED_TEST(TestGemmAddSilu, Test_BF16FP16_INT8) { this->Run(); } +TYPED_TEST(TestGemmAddSilu, Test_BF16FP16_BF16FP16_INT8) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_silu_wmma.cpp b/test/gemm_add/test_gemm_add_silu_wmma.cpp deleted file mode 100644 index f68f67a36f..0000000000 --- a/test/gemm_add/test_gemm_add_silu_wmma.cpp +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "gtest/gtest.h" -#include "ck/ck.hpp" -#include "profiler/profile_gemm_add_silu_impl.hpp" -#include "test_gemm_common.hpp" - -template -class TestGemmAddSilu : public TestGemmD0Common -{ - using ProfileCall = typename TestGemmD0Common::ProfileCall; - - ProfileCall GetImpl() override - { - return ck::profiler::profile_gemm_add_silu_impl< - typename TestGemmD0Common::ADataType, - typename TestGemmD0Common::BDataType, - typename TestGemmD0Common::AccDataType, - typename TestGemmD0Common::D0DataType, - typename TestGemmD0Common::EDataType, - typename TestGemmD0Common::ALayout, - typename TestGemmD0Common::BLayout, - typename TestGemmD0Common::D0Layout, - typename TestGemmD0Common::ELayout>; - } -}; - -using KernelTypes = ::testing::Types, - std::tuple>; - -TYPED_TEST_SUITE(TestGemmAddSilu, KernelTypes); -TYPED_TEST(TestGemmAddSilu, Test_BF16FP16_BF16FP16) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_xdl.cpp b/test/gemm_add/test_gemm_add_xdl.cpp deleted file mode 100644 index 873e87edd4..0000000000 --- a/test/gemm_add/test_gemm_add_xdl.cpp +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "gtest/gtest.h" -#include "ck/ck.hpp" -#include "profiler/profile_gemm_add_impl.hpp" -#include "test_gemm_common.hpp" - -template -class TestGemmAdd : public TestGemmD0Common -{ - using ProfileCall = typename TestGemmD0Common::ProfileCall; - - ProfileCall GetImpl() override - { - return ck::profiler::profile_gemm_add_impl::ADataType, - typename TestGemmD0Common::BDataType, - typename TestGemmD0Common::AccDataType, - typename TestGemmD0Common::D0DataType, - typename TestGemmD0Common::EDataType, - typename TestGemmD0Common::ALayout, - typename TestGemmD0Common::BLayout, - typename TestGemmD0Common::D0Layout, - typename TestGemmD0Common::ELayout>; - } -}; - -using KernelTypes = ::testing::Types, - std::tuple>; - -TYPED_TEST_SUITE(TestGemmAdd, KernelTypes); -TYPED_TEST(TestGemmAdd, Test_BF16FP16_INT8) { this->Run(); } diff --git a/test/gemm_b_scale/CMakeLists.txt b/test/gemm_b_scale/CMakeLists.txt index 517e2f01f6..b386ec67df 100644 --- a/test/gemm_b_scale/CMakeLists.txt +++ b/test/gemm_b_scale/CMakeLists.txt @@ -1,12 +1,12 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -add_gtest_executable(test_gemm_b_scale_xdl test_gemm_b_scale_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_b_scale_xdl PRIVATE utility device_gemm_b_scale_instance) -endif() - -add_gtest_executable(test_gemm_b_scale_wmma test_gemm_b_scale_wmma.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_b_scale_wmma PRIVATE utility device_gemm_b_scale_instance) +# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary +# as these tests are universal and don't have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link +# the instance library if there are no instances present for the current arch. +if (CK_USE_XDL OR CK_USE_WMMA) + add_gtest_executable(test_gemm_b_scale test_gemm_b_scale.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_b_scale PRIVATE utility device_gemm_b_scale_instance) + endif() endif() diff --git a/test/gemm_b_scale/test_gemm_b_scale_wmma.cpp b/test/gemm_b_scale/test_gemm_b_scale.cpp similarity index 100% rename from test/gemm_b_scale/test_gemm_b_scale_wmma.cpp rename to test/gemm_b_scale/test_gemm_b_scale.cpp diff --git a/test/gemm_b_scale/test_gemm_b_scale_xdl.cpp b/test/gemm_b_scale/test_gemm_b_scale_xdl.cpp deleted file mode 100644 index 93eb128bb0..0000000000 --- a/test/gemm_b_scale/test_gemm_b_scale_xdl.cpp +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
-// SPDX-License-Identifier: MIT - -#include - -#include "gtest/gtest.h" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "test_gemm_b_scale_util.hpp" - -using I4 = ck::pk_i4_t; -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -namespace { - -template -struct tuple_concat; - -template -struct tuple_concat, std::tuple> -{ - using type = std::tuple; -}; - -} // namespace - -template -class TestGemmBScale_MK_NK - : public ck::test::TestGemmBScale, Tuple>::type> -{ -}; - -// clang-format off -using KernelTypes_MK_NK = ::testing::Types< - // ADataType, BDataType, BScaleDataType, ComputeDataType, CDataType - std::tuple< F16, I4, F16, F16, F16> - >; -// clang-format on - -TYPED_TEST_SUITE(TestGemmBScale_MK_NK, KernelTypes_MK_NK); - -#include "test_gemm_b_scale_ut_cases.inc" diff --git a/test/gemm_multi_abd/CMakeLists.txt b/test/gemm_multi_abd/CMakeLists.txt index 9b1454ca93..2e327454f2 100644 --- a/test/gemm_multi_abd/CMakeLists.txt +++ b/test/gemm_multi_abd/CMakeLists.txt @@ -1,12 +1,12 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -add_gtest_executable(test_gemm_multi_abd_wmma test_gemm_multi_abd_wmma.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_multi_abd_wmma PRIVATE utility device_gemm_multi_abd_instance) -endif() - -add_gtest_executable(test_gemm_multi_abd_xdl test_gemm_multi_abd_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_multi_abd_xdl PRIVATE utility device_gemm_multi_abd_instance) -endif() +# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary +# as these tests are universal and don't have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link +# the instance library if there are no instances present for the current arch. +if (CK_USE_XDL OR CK_USE_WMMA) + add_gtest_executable(test_gemm_multi_abd test_gemm_multi_abd.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_multi_abd PRIVATE utility device_gemm_multi_abd_instance) + endif() +endif() \ No newline at end of file diff --git a/test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp b/test/gemm_multi_abd/test_gemm_multi_abd.cpp similarity index 100% rename from test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp rename to test/gemm_multi_abd/test_gemm_multi_abd.cpp diff --git a/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp b/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp deleted file mode 100644 index ed3fbbf087..0000000000 --- a/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
-// SPDX-License-Identifier: MIT - -#include - -#include "gtest/gtest.h" -#include "ck/ck.hpp" -#include "profiler/profile_gemm_multi_abd_impl.hpp" -#include "test_gemm_common.hpp" - -namespace ck { -namespace test { - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using I8 = int8_t; -using BF16 = ck::bhalf_t; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using Multiply = ck::tensor_operation::element_wise::Multiply; -using Add = ck::tensor_operation::element_wise::Add; -using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; -using FastGelu = ck::tensor_operation::element_wise::FastGelu; -using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; -using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; -using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; - -using KernelTypesABD = ::testing::Types, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - Add>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - Add>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - AddFastGelu>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - AddFastGelu>, - std::tuple, - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - FastGelu>, - std::tuple, - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - FastGelu>, - std::tuple, - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - PassThrough>, - std::tuple, - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - PassThrough>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyAddFastGelu>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyAdd>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyFastGelu>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - Multiply>>; - -TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD); -TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); } - -} // namespace test -} // namespace ck diff --git a/test/gemm_universal/CMakeLists.txt b/test/gemm_universal/CMakeLists.txt index 5be42aae90..d48343797a 100644 --- a/test/gemm_universal/CMakeLists.txt +++ b/test/gemm_universal/CMakeLists.txt @@ -1,32 +1,23 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -add_gtest_executable(test_gemm_universal_wmma_fp16 test_gemm_universal_wmma_fp16.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_universal_wmma_fp16 PRIVATE utility device_gemm_universal_instance) -endif() -add_gtest_executable(test_gemm_universal_wmma_bf16 test_gemm_universal_wmma_bf16.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_universal_wmma_bf16 PRIVATE utility device_gemm_universal_instance) -endif() +# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary +# as these tests are universal and don't have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link +# the instance library if there are no instances present for the current arch. +if (CK_USE_XDL OR CK_USE_WMMA) + add_gtest_executable(test_gemm_universal_fp16 test_gemm_universal_fp16.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_universal_fp16 PRIVATE utility device_gemm_universal_instance) + endif() -add_gtest_executable(test_gemm_universal_wmma_fp8 test_gemm_universal_wmma_fp8.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_universal_wmma_fp8 PRIVATE utility device_gemm_universal_instance) -endif() + add_gtest_executable(test_gemm_universal_fp8 test_gemm_universal_fp8.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_universal_fp8 PRIVATE utility device_gemm_universal_instance) + endif() -add_gtest_executable(test_gemm_universal_xdl_fp16 test_gemm_universal_xdl_fp16.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_universal_xdl_fp16 PRIVATE utility device_gemm_universal_instance) -endif() - -add_gtest_executable(test_gemm_universal_xdl_fp8 test_gemm_universal_xdl_fp8.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_universal_xdl_fp8 PRIVATE utility device_gemm_universal_instance) -endif() - -add_gtest_executable(test_gemm_universal_xdl_bf16 test_gemm_universal_xdl_bf16.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_universal_xdl_bf16 PRIVATE utility device_gemm_universal_instance) + add_gtest_executable(test_gemm_universal_bf16 test_gemm_universal_bf16.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_universal_bf16 PRIVATE utility device_gemm_universal_instance) + endif() endif() diff --git a/test/gemm_universal/test_gemm_universal_wmma_bf16.cpp b/test/gemm_universal/test_gemm_universal_bf16.cpp similarity index 95% rename from test/gemm_universal/test_gemm_universal_wmma_bf16.cpp rename to test/gemm_universal/test_gemm_universal_bf16.cpp index e9f25df162..a4306e6916 100644 --- a/test/gemm_universal/test_gemm_universal_wmma_bf16.cpp +++ b/test/gemm_universal/test_gemm_universal_bf16.cpp @@ -55,7 +55,8 @@ class TestGemmUniversal_BF16_KM_NK // clang-format off using KernelTypes_MK_KN = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< BF16, BF16, BF16, BF16> + + std::tuple< BF16, BF16, BF16, BF16> >; using KernelTypes_MK_NK = ::testing::Types< @@ -66,11 +67,6 @@ using KernelTypes_MK_NK = ::testing::Types< std::tuple< BF16, BF16, BF16, BF16> >; -using KernelTypes_KM_KN = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< BF16, BF16, BF16, BF16> - >; - using KernelTypes_KM_NK = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType #if defined(CK_ENABLE_FP8) @@ -78,6 +74,12 @@ using KernelTypes_KM_NK = ::testing::Types< std::tuple< BF16, BF16, BF16, BF16> >; + +using KernelTypes_KM_KN = ::testing::Types<
+ // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; + // clang-format on TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN); diff --git a/test/gemm_universal/test_gemm_universal_wmma_fp16.cpp b/test/gemm_universal/test_gemm_universal_fp16.cpp similarity index 100% rename from test/gemm_universal/test_gemm_universal_wmma_fp16.cpp rename to test/gemm_universal/test_gemm_universal_fp16.cpp diff --git a/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp b/test/gemm_universal/test_gemm_universal_fp8.cpp similarity index 76% rename from test/gemm_universal/test_gemm_universal_xdl_fp8.cpp rename to test/gemm_universal/test_gemm_universal_fp8.cpp index 49a0670528..636305c96a 100644 --- a/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp +++ b/test/gemm_universal/test_gemm_universal_fp8.cpp @@ -44,31 +44,34 @@ class TestGemmUniversal_FP8_MK_NK // clang-format off using KernelTypes_MK_KN = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) && !defined(CK_USE_WMMA_FP8) std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - std::tuple< F8, F8, F8, BF16>, -#endif + std::tuple< F8, F16, F16, F16>>; +#elif defined(CK_USE_WMMA_FP8) + // Fallback test type when WMMA FP8 is used + std::tuple< F8, F8, F8, BF16>>; +#else // Fallback test type when FP8 is not enabled - std::tuple< F16, F16, F16, F16> - >; + std::tuple< F16, F16, F16, F16>>; +#endif using KernelTypes_MK_NK = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) && !defined(CK_USE_WMMA_FP8) std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - std::tuple< F8, F8, F8, BF16>, -#endif + std::tuple< F8, F16, F16, F16>>; +#elif defined(CK_USE_WMMA_FP8) + // Fallback test type when WMMA FP8 is used + std::tuple< F8, F8, F8, BF16>>; +#else // Fallback test type when FP8 is not enabled - std::tuple< F16, F16, F16, F16> - >; - + std::tuple< F16, F16, F16, F16>>; +#endif +// clang-format on TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN); TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK); - #include "test_gemm_universal_ut_cases_fp8.inc" int main(int argc, char** argv) { diff --git a/test/gemm_universal/test_gemm_universal_wmma_fp8.cpp b/test/gemm_universal/test_gemm_universal_wmma_fp8.cpp deleted file mode 100644 index 5d54144747..0000000000 --- a/test/gemm_universal/test_gemm_universal_wmma_fp8.cpp +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include - -#include "gtest/gtest.h" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "test_gemm_universal_util.hpp" -ck::index_t param_mask = 0xffff; -ck::index_t instance_index = -1; -#if defined(CK_USE_WMMA_FP8) - -using F8 = ck::f8_t; -using BF16 = ck::bhalf_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -namespace { - -template -struct tuple_concat; - -template -struct tuple_concat, std::tuple> -{ - using type = std::tuple; -}; - -} // namespace - -template -class TestGemmUniversal_FP8_MK_KN - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -template -class TestGemmUniversal_FP8_MK_NK - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -// clang-format off -using KernelTypes_MK_KN = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< F8, F8, F8, BF16> - >; - -using KernelTypes_MK_NK = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< F8, F8, F8, BF16> - >; -// clang-format on - -TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN); -TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK); - -#include "test_gemm_universal_ut_cases_fp8.inc" - -#endif // CK_USE_WMMA_FP8 -int main(int argc, char** argv) -{ - testing::InitGoogleTest(&argc, argv); - if(argc == 1) {} - else if(argc == 3) - { - param_mask = strtol(argv[1], nullptr, 0); - instance_index = atoi(argv[2]); - } - else - { - std::cout << "Usage of " << argv[0] << std::endl; - std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; - } - return RUN_ALL_TESTS(); -} diff --git a/test/gemm_universal/test_gemm_universal_xdl_bf16.cpp b/test/gemm_universal/test_gemm_universal_xdl_bf16.cpp deleted file mode 100644 index 18031cd762..0000000000 --- a/test/gemm_universal/test_gemm_universal_xdl_bf16.cpp +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include - -#include "gtest/gtest.h" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "test_gemm_universal_util.hpp" -ck::index_t param_mask = 0xffff; -ck::index_t instance_index = -1; -using BF16 = ck::bhalf_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -namespace { - -template -struct tuple_concat; - -template -struct tuple_concat, std::tuple> -{ - using type = std::tuple; -}; - -} // namespace - -template -class TestGemmUniversal_BF16_MK_KN - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -template -class TestGemmUniversal_BF16_MK_NK - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -template -class TestGemmUniversal_BF16_KM_KN - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -template -class TestGemmUniversal_BF16_KM_NK - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -// clang-format off -using KernelTypes_MK_KN = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - - std::tuple< BF16, BF16, BF16, BF16> - >; -using KernelTypes_MK_NK = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - - std::tuple< BF16, BF16, BF16, BF16> - >; - -using KernelTypes_KM_NK = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< BF16, BF16, BF16, BF16> - >; - -using KernelTypes_KM_KN = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< BF16, BF16, BF16, BF16> - >; - -// clang-format on - -TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN); -TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK); -TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN); -TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK); - -#include "test_gemm_universal_ut_cases_bf16.inc" -int main(int argc, char** argv) -{ - testing::InitGoogleTest(&argc, argv); - if(argc == 1) {} - else if(argc == 3) - { - param_mask = strtol(argv[1], nullptr, 0); - instance_index = atoi(argv[2]); - } - else - { - std::cout << "Usage of " << argv[0] << std::endl; - std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; - } - return RUN_ALL_TESTS(); -} diff --git a/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp b/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp deleted file mode 100644 index 9e99b45e80..0000000000 --- a/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include - -#include "gtest/gtest.h" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "test_gemm_universal_util.hpp" -ck::index_t param_mask = 0xffff; -ck::index_t instance_index = -1; -using F8 = ck::f8_t; -using F16 = ck::half_t; - -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -namespace { - -template -struct tuple_concat; - -template -struct tuple_concat, std::tuple> -{ - using type = std::tuple; -}; - -} // namespace - -template -class TestGemmUniversal_FP16_MK_KN - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -template -class TestGemmUniversal_FP16_MK_NK - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -template -class TestGemmUniversal_FP16_KM_KN - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -template -class TestGemmUniversal_FP16_KM_NK - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -// clang-format off -using KernelTypes_MK_KN = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - -#endif - std::tuple< F16, F16, F16, F16> - >; - -using KernelTypes_MK_NK = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - -#endif - std::tuple< F16, F16, F16, F16> - >; - -using KernelTypes_KM_NK = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< F16, F16, F16, F16> - >; - -using KernelTypes_KM_KN = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< F16, F16, F16, F16> - >; -// clang-format on - -TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN); -TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK); -TYPED_TEST_SUITE(TestGemmUniversal_FP16_KM_NK, KernelTypes_KM_NK); -TYPED_TEST_SUITE(TestGemmUniversal_FP16_KM_KN, KernelTypes_KM_KN); - -#include "test_gemm_universal_ut_cases_fp16.inc" -int main(int argc, char** argv) -{ - testing::InitGoogleTest(&argc, argv); - if(argc == 1) {} - else if(argc == 3) - { - param_mask = strtol(argv[1], nullptr, 0); - instance_index = atoi(argv[2]); - } - else - { - std::cout << "Usage of " << argv[0] << std::endl; - std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; - } - return RUN_ALL_TESTS(); -} diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index 0f6285cfea..514f8e9668 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -21,11 +21,10 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") target_compile_options(test_grouped_conv_bwd_data_scale PRIVATE -Wno-global-constructors -Wno-undef) target_link_libraries(test_grouped_conv_bwd_data_scale PRIVATE gtest_main getopt::getopt utility device_grouped_conv3d_bwd_data_scale_instance) endif() -add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_data_interface_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance) -endif() -add_gtest_executable(test_grouped_convnd_bwd_data_interface_wmma 
test_grouped_convnd_bwd_data_interface_wmma.cpp) -if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_data_interface_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance) + +if (CK_USE_XDL OR CK_USE_WMMA) + add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface.cpp) + if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) + endif() endif() \ No newline at end of file diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_xdl.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface.cpp similarity index 52% rename from test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_xdl.cpp rename to test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface.cpp index 969960275f..ab89d9d0f0 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_xdl.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface.cpp @@ -12,6 +12,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/algorithm.hpp" @@ -32,7 +33,7 @@ static constexpr auto ConvBwdDataDefault = ConvBackwardDataSpecialization::Def static constexpr auto Filter1x1Stride1Pad0 = ConvBackwardDataSpecialization::Filter1x1Stride1Pad0; template -class TestGroupedConvndBwdData : public ::testing::Test +class TestGroupedConvndBwdDataXdl : public ::testing::Test { protected: static constexpr ck::index_t NDimSpatial = 2; @@ -119,6 +120,100 @@ class TestGroupedConvndBwdData : public ::testing::Test } }; +template +class TestGroupedConvndBwdDataWmma : public ::testing::Test +{ + protected: + static constexpr ck::index_t NDimSpatial = 2; + + using OutLayout = std::tuple_element_t<0, Tuple>; + using WeiLayout = std::tuple_element_t<1, Tuple>; + using InLayout = std::tuple_element_t<2, Tuple>; + + // clang-format off + using GroupedConvBwdDataDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle + //| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| 
_NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial,OutLayout, WeiLayout, ck::Tuple<>, InLayout, DataType, DataType, AccDataType, DataType, ck::Tuple<>, DataType, Pass, Pass, Pass, ConvSpec, 64, 32, 64, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>; + // clang-format on + + ck::utils::conv::ConvParam conv_param; + + void SetUp() override + { + if(!ck::is_gfx11_supported()) + { + GTEST_SKIP(); + } + } + + template + bool Run() + { + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + std::array out_lengths{}; + std::array out_strides{}; + std::array wei_lengths{}; + std::array wei_strides{}; + std::array in_lengths{}; + std::array in_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(out_g_n_k_wos_desc.GetLengths(), out_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), out_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), wei_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), wei_strides); + copy(in_g_n_c_wis_desc.GetLengths(), in_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), in_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + auto conv = GroupedConvBwdDataDeviceInstance{}; + + auto argument = conv.MakeArgument(nullptr, + nullptr, + std::array{}, + nullptr, + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + Pass{}, + Pass{}, + Pass{}); + return conv.IsSupportedArgument(argument); + } +}; + using GNHWC = ck::tensor_layout::convolution::GNHWC; using NHWGC = ck::tensor_layout::convolution::NHWGC; @@ -131,20 +226,35 @@ using KernelTypes = ::testing::Types, std::tuple>; template -class TestGroupedConvndBwdDataDefault : public TestGroupedConvndBwdData +class TestGroupedConvndBwdDataDefaultXdl + : public TestGroupedConvndBwdDataXdl { }; template -class TestGroupedConvndBwdDataFilter1x1 - : public TestGroupedConvndBwdData +class TestGroupedConvndBwdDataFilter1x1Xdl + : public TestGroupedConvndBwdDataXdl { }; -TYPED_TEST_SUITE(TestGroupedConvndBwdDataDefault, KernelTypes); -TYPED_TEST_SUITE(TestGroupedConvndBwdDataFilter1x1, KernelTypes); +template +class TestGroupedConvndBwdDataDefaultWmma + : public TestGroupedConvndBwdDataWmma +{ +}; -TYPED_TEST(TestGroupedConvndBwdDataFilter1x1, SpecializationCheck) +template +class TestGroupedConvndBwdDataFilter1x1Wmma + : public TestGroupedConvndBwdDataWmma +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndBwdDataDefaultXdl, KernelTypes); +TYPED_TEST_SUITE(TestGroupedConvndBwdDataFilter1x1Xdl, KernelTypes); +TYPED_TEST_SUITE(TestGroupedConvndBwdDataDefaultWmma, KernelTypes); +TYPED_TEST_SUITE(TestGroupedConvndBwdDataFilter1x1Wmma, KernelTypes); + 
+TYPED_TEST(TestGroupedConvndBwdDataFilter1x1Xdl, SpecializationCheckXdl) { // Check filter 3,3 instead of 1,1 this->conv_param = {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; @@ -167,7 +277,30 @@ TYPED_TEST(TestGroupedConvndBwdDataFilter1x1, SpecializationCheck) EXPECT_TRUE(is_supported); } -TYPED_TEST(TestGroupedConvndBwdDataDefault, VectorLoadCheck) +TYPED_TEST(TestGroupedConvndBwdDataFilter1x1Wmma, SpecializationCheckWmma) +{ + // Check filter 3,3 instead of 1,1 + this->conv_param = {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + bool is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Check strides 2,2 instead of 1,1 + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Check with pad + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Supported version + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_TRUE(is_supported); +} + +TYPED_TEST(TestGroupedConvndBwdDataDefaultXdl, VectorLoadCheckXdl) { // vector load for A this->conv_param = {2, 2, 128, 129, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; @@ -179,7 +312,19 @@ TYPED_TEST(TestGroupedConvndBwdDataDefault, VectorLoadCheck) EXPECT_FALSE(is_supported); } -TYPED_TEST(TestGroupedConvndBwdDataDefault, SplitK) +TYPED_TEST(TestGroupedConvndBwdDataDefaultWmma, VectorLoadCheckWmma) +{ + // vector load for A + this->conv_param = {2, 2, 128, 129, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + bool is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + // vector load for B, E, Ds + this->conv_param = {2, 2, 128, 128, 257, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); +} + +TYPED_TEST(TestGroupedConvndBwdDataDefaultXdl, SplitK) { if(ck::is_xdl_supported()) { diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp deleted file mode 100644 index 871c41e706..0000000000 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp
deleted file mode 100644
index 871c41e706..0000000000
--- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp
+++ /dev/null
@@ -1,186 +0,0 @@
-// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
-// SPDX-License-Identifier: MIT
-
-#include <cstdlib>
-#include <initializer_list>
-#include <iostream>
-#include <numeric>
-#include <tuple>
-
-#include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
-
-#include "ck/library/utility/convolution_parameter.hpp"
-#include "ck/library/utility/algorithm.hpp"
-#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
-
-#include <gtest/gtest.h>
-
-using DataType = ck::half_t;
-using AccDataType = float;
-using Pass = ck::tensor_operation::element_wise::PassThrough;
-
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-using ConvBackwardDataSpecialization =
-    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
-
-static constexpr auto ConvBwdDataDefault = ConvBackwardDataSpecialization::Default;
-static constexpr auto Filter1x1Stride1Pad0 = ConvBackwardDataSpecialization::Filter1x1Stride1Pad0;
-
-template <typename Tuple, ConvBackwardDataSpecialization ConvSpec>
-class TestGroupedConvndBwdData : public ::testing::Test
-{
-    protected:
-    static constexpr ck::index_t NDimSpatial = 2;
-
-    using OutLayout = std::tuple_element_t<0, Tuple>;
-    using WeiLayout = std::tuple_element_t<1, Tuple>;
-    using InLayout  = std::tuple_element_t<2, Tuple>;
-
-    // clang-format off
-    using GroupedConvBwdDataDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
-        //| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-        //| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
-        //| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-        //| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < NDimSpatial,OutLayout, WeiLayout, ck::Tuple<>, InLayout, DataType, DataType, AccDataType, DataType, ck::Tuple<>, DataType, Pass, Pass, Pass, ConvSpec, 64, 32, 64, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>;
-    // clang-format on
-
-    ck::utils::conv::ConvParam conv_param;
-
-    void SetUp() override
-    {
-        if(!ck::is_gfx11_supported())
-        {
-            GTEST_SKIP();
-        }
-    }
-
-    template <ck::index_t NDimSpatial>
-    bool Run()
-    {
-
-        const auto out_g_n_k_wos_desc =
-            ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
-                conv_param);
-
-        const auto wei_g_k_c_xs_desc =
-            ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
-                conv_param);
-
-        const auto in_g_n_c_wis_desc =
-            ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
-                conv_param);
-
-        std::array<ck::index_t, NDimSpatial + 3> out_lengths{};
-        std::array<ck::index_t, NDimSpatial + 3> out_strides{};
-        std::array<ck::index_t, NDimSpatial + 3> wei_lengths{};
-        std::array<ck::index_t, NDimSpatial + 3> wei_strides{};
-        std::array<ck::index_t, NDimSpatial + 3> in_lengths{};
-        std::array<ck::index_t, NDimSpatial + 3> in_strides{};
-        std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
-        std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
-        std::array<ck::index_t, NDimSpatial> input_left_pads{};
-        std::array<ck::index_t, NDimSpatial> input_right_pads{};
-
-        auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
-
-        copy(out_g_n_k_wos_desc.GetLengths(), out_lengths);
-        copy(out_g_n_k_wos_desc.GetStrides(), out_strides);
-        copy(wei_g_k_c_xs_desc.GetLengths(), wei_lengths);
-        copy(wei_g_k_c_xs_desc.GetStrides(), wei_strides);
-        copy(in_g_n_c_wis_desc.GetLengths(), in_lengths);
-        copy(in_g_n_c_wis_desc.GetStrides(), in_strides);
-        copy(conv_param.conv_filter_strides_, conv_filter_strides);
-        copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
-        copy(conv_param.input_left_pads_, input_left_pads);
-        copy(conv_param.input_right_pads_, input_right_pads);
-
-        auto conv = GroupedConvBwdDataDeviceInstance{};
-
-        auto argument = conv.MakeArgument(nullptr,
-                                          nullptr,
-                                          std::array<const void*, 0>{},
-                                          nullptr,
-                                          out_lengths,
-                                          out_strides,
-                                          wei_lengths,
-                                          wei_strides,
-                                          {},
-                                          {},
-                                          in_lengths,
-                                          in_strides,
-                                          conv_filter_strides,
-                                          conv_filter_dilations,
-                                          input_left_pads,
-                                          input_right_pads,
-                                          Pass{},
-                                          Pass{},
-                                          Pass{});
-        return conv.IsSupportedArgument(argument);
-    }
-};
-
-using GNHWC = ck::tensor_layout::convolution::GNHWC;
-using NHWGC = ck::tensor_layout::convolution::NHWGC;
-
-using GKYXC = ck::tensor_layout::convolution::GKYXC;
-
-using GNHWK = ck::tensor_layout::convolution::GNHWK;
-using NHWGK = ck::tensor_layout::convolution::NHWGK;
-
-using KernelTypes =
-    ::testing::Types<std::tuple<GNHWK, GKYXC, GNHWC>, std::tuple<NHWGK, GKYXC, NHWGC>>;
-
-template <typename Tuple>
-class TestGroupedConvndBwdDataDefault : public TestGroupedConvndBwdData<Tuple, ConvBwdDataDefault>
-{
-};
-
-template <typename Tuple>
-class TestGroupedConvndBwdDataFilter1x1
-    : public TestGroupedConvndBwdData<Tuple, Filter1x1Stride1Pad0>
-{
-};
-
-TYPED_TEST_SUITE(TestGroupedConvndBwdDataDefault, KernelTypes);
-TYPED_TEST_SUITE(TestGroupedConvndBwdDataFilter1x1, KernelTypes);
-
-TYPED_TEST(TestGroupedConvndBwdDataFilter1x1, SpecializationCheck)
-{
-    // Check filter 3,3 instead of 1,1
-    this->conv_param = {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
-    bool is_supported = this->template Run<2>();
-    EXPECT_FALSE(is_supported);
-
-    // Check strides 2,2 instead of 1,1
-    this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
-    is_supported = this->template Run<2>();
-    EXPECT_FALSE(is_supported);
-
-    // Check with pad
-    this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}};
-    is_supported = this->template Run<2>();
-    EXPECT_FALSE(is_supported);
-
-    // Supported version
-    this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
-    is_supported = this->template Run<2>();
-    EXPECT_TRUE(is_supported);
-}
-
-TYPED_TEST(TestGroupedConvndBwdDataDefault, VectorLoadCheck)
-{
-    // vector load for A
-    this->conv_param = {2, 2, 128, 129, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
-    bool is_supported = this->template Run<2>();
-    EXPECT_FALSE(is_supported);
-    // vector load for B, E, Ds
-    this->conv_param = {2, 2, 128, 128, 257, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
-    is_supported = this->template Run<2>();
-    EXPECT_FALSE(is_supported);
-}