From bcf93e292c77e26fad1208bd024357a0ce39a455 Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Mon, 2 Jun 2025 15:09:31 +0500 Subject: [PATCH] Prepare gemma_add tests for adding wmma --- test/gemm_add/CMakeLists.txt | 16 ++--- test/gemm_add/test_gemm_add_fastgelu_xdl.cpp | 6 +- test/gemm_add/test_gemm_add_relu_xdl.cpp | 6 +- test/gemm_add/test_gemm_add_silu_xdl.cpp | 6 +- test/gemm_add/test_gemm_add_xdl.hpp | 42 ++----------- test/gemm_add/test_gemm_common.hpp | 66 ++++++++++++++++++++ 6 files changed, 88 insertions(+), 54 deletions(-) create mode 100644 test/gemm_add/test_gemm_common.hpp diff --git a/test/gemm_add/CMakeLists.txt b/test/gemm_add/CMakeLists.txt index ab4c781847..7b5fa74ca2 100644 --- a/test/gemm_add/CMakeLists.txt +++ b/test/gemm_add/CMakeLists.txt @@ -1,19 +1,19 @@ -add_gtest_executable(test_gemm_add test_gemm_add_xdl.hpp) +add_gtest_executable(test_gemm_add_xdl test_gemm_add_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance) + target_link_libraries(test_gemm_add_xdl PRIVATE utility device_gemm_add_instance) endif() -add_gtest_executable(test_gemm_add_relu test_gemm_add_relu_xdl.cpp) +add_gtest_executable(test_gemm_add_relu_xdl test_gemm_add_relu_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) + target_link_libraries(test_gemm_add_relu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) endif() -add_gtest_executable(test_gemm_add_silu test_gemm_add_silu_xdl.cpp) +add_gtest_executable(test_gemm_add_silu_xdl test_gemm_add_silu_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) + target_link_libraries(test_gemm_add_silu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) endif() -add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu_xdl.cpp) +add_gtest_executable(test_gemm_add_fastgelu_xdl test_gemm_add_fastgelu_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) + target_link_libraries(test_gemm_add_fastgelu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) endif() diff --git a/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp index 1b12ab7528..2c055a8006 100644 --- a/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp +++ b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp @@ -1,13 +1,13 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_fastgelu_impl.hpp" -#include "test_gemm_add_xdl.hpp" +#include "test_gemm_common.hpp" template -class TestGemmAddFastgelu : public TestGemmAdd +class TestGemmAddFastgelu : public TestGemmD0Common { private: using ADataType = std::tuple_element_t<0, Tuple>; diff --git a/test/gemm_add/test_gemm_add_relu_xdl.cpp b/test/gemm_add/test_gemm_add_relu_xdl.cpp index e8b769b1cb..35aaba96b1 100644 --- a/test/gemm_add/test_gemm_add_relu_xdl.cpp +++ b/test/gemm_add/test_gemm_add_relu_xdl.cpp @@ -1,13 +1,13 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_relu_impl.hpp" -#include "test_gemm_add_xdl.hpp" +#include "test_gemm_common.hpp" template -class TestGemmAddRelu : public TestGemmAdd +class TestGemmAddRelu : public TestGemmD0Common { private: using ADataType = std::tuple_element_t<0, Tuple>; diff --git a/test/gemm_add/test_gemm_add_silu_xdl.cpp b/test/gemm_add/test_gemm_add_silu_xdl.cpp index 75fa59a8e7..8d242869c6 100644 --- a/test/gemm_add/test_gemm_add_silu_xdl.cpp +++ b/test/gemm_add/test_gemm_add_silu_xdl.cpp @@ -1,13 +1,13 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_silu_impl.hpp" -#include "test_gemm_add_xdl.hpp" +#include "test_gemm_common.hpp" template -class TestGemmAddSilu : public TestGemmAdd +class TestGemmAddSilu : public TestGemmD0Common { private: using ADataType = std::tuple_element_t<0, Tuple>; diff --git a/test/gemm_add/test_gemm_add_xdl.hpp b/test/gemm_add/test_gemm_add_xdl.hpp index 11d3d1c10a..3cc5405b5f 100644 --- a/test/gemm_add/test_gemm_add_xdl.hpp +++ b/test/gemm_add/test_gemm_add_xdl.hpp @@ -1,22 +1,15 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_impl.hpp" - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using I8 = int8_t; -using BF16 = ck::bhalf_t; -using F16 = ck::half_t; -using F32 = float; +#include "test_gemm_common.hpp" template -class TestGemmAdd : public ::testing::Test +class TestGemmAdd : public TestGemmD0Common { - protected: + private: using ADataType = std::tuple_element_t<0, Tuple>; using BDataType = std::tuple_element_t<1, Tuple>; using AccDataType = std::tuple_element_t<2, Tuple>; @@ -37,32 +30,7 @@ class TestGemmAdd : public ::testing::Test D0Layout, ELayout>; - virtual decltype(ProfileGemmAddImpl) GetImpl() { return ProfileGemmAddImpl; } - - void Run() - { - std::vector> lengths = { - {16, 32, 64}, {2048, 4096, 8192}, {2048, 1024, 16}}; - - bool all_success = true; - - for(auto length : lengths) - { - int M = length[0]; - int N = length[1]; - int K = length[2]; - int StrideA = ck::is_same_v ? K : M; - int StrideB = ck::is_same_v ? N : K; - int StrideD0 = ck::is_same_v ? N : M; - int StrideE = ck::is_same_v ? N : M; - - all_success = - all_success & - GetImpl()(true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideE); - } - - EXPECT_TRUE(all_success); - } + decltype(ProfileGemmAddImpl) GetImpl() override { return ProfileGemmAddImpl; } }; using KernelTypes = ::testing::Types, diff --git a/test/gemm_add/test_gemm_common.hpp b/test/gemm_add/test_gemm_common.hpp new file mode 100644 index 0000000000..1cf41d7538 --- /dev/null +++ b/test/gemm_add/test_gemm_common.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_impl.hpp" + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using I8 = int8_t; +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +class TestGemmD0Common : public ::testing::Test +{ + protected: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using D0DataType = std::tuple_element_t<3, Tuple>; + using EDataType = std::tuple_element_t<4, Tuple>; + using ALayout = std::tuple_element_t<5, Tuple>; + using BLayout = std::tuple_element_t<6, Tuple>; + using D0Layout = std::tuple_element_t<7, Tuple>; + using ELayout = std::tuple_element_t<8, Tuple>; + + constexpr static auto ProfileGemmAddImpl = ck::profiler::profile_gemm_add_impl; + + virtual decltype(ProfileGemmAddImpl) GetImpl() = 0; + + void Run() + { + std::vector> lengths = { + {16, 32, 64}, {2048, 4096, 8192}, {2048, 1024, 16}}; + + bool all_success = true; + + for(auto length : lengths) + { + int M = length[0]; + int N = length[1]; + int K = length[2]; + int StrideA = ck::is_same_v ? K : M; + int StrideB = ck::is_same_v ? N : K; + int StrideD0 = ck::is_same_v ? N : M; + int StrideE = ck::is_same_v ? N : M; + + all_success = + all_success & + GetImpl()(true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideE); + } + + EXPECT_TRUE(all_success); + } +};