diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp index 01d2480fd1..63d9b73901 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp @@ -24,10 +24,10 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto V1 = BlockGemmPipelineVersion::v1; static constexpr auto V3 = BlockGemmPipelineVersion::v3; -// e = elementwise((a * b), d0) -// elementwise(c, d0) = fastgelu(c + d0) +// e = elementwise((a * b), d0, d1) +// elementwise(c, d0, d1) = fastgelu(c + d0 + d1) // output: e[m, n] -// input: a[m, k], b[n, k], d0[m, n] +// input: a[m, k], b[n, k], d0[m, n], d1[m, n] template using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances = diff --git a/test/gemm_add/CMakeLists.txt b/test/gemm_add/CMakeLists.txt index 9e061f7c41..88a0cfd0e2 100644 --- a/test/gemm_add/CMakeLists.txt +++ b/test/gemm_add/CMakeLists.txt @@ -20,7 +20,17 @@ if(result EQUAL 0) target_link_libraries(test_gemm_add_fastgelu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) endif() +add_gtest_executable(test_gemm_fastgelu_wmma test_gemm_fastgelu_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_fastgelu_wmma PRIVATE utility device_gemm_fastgelu_instance) +endif() + add_gtest_executable(test_gemm_add_fastgelu_wmma test_gemm_add_fastgelu_wmma.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_add_fastgelu_wmma PRIVATE utility device_gemm_add_fastgelu_instance) endif() + +add_gtest_executable(test_gemm_add_add_fastgelu_wmma test_gemm_add_add_fastgelu_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_add_fastgelu_wmma PRIVATE utility device_gemm_add_add_fastgelu_instance) +endif() diff --git a/test/gemm_add/test_gemm_add_add_fastgelu_wmma.cpp b/test/gemm_add/test_gemm_add_add_fastgelu_wmma.cpp new file mode 100644 index 0000000000..2cde4c7ea3 --- /dev/null +++ b/test/gemm_add/test_gemm_add_add_fastgelu_wmma.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_add_fastgelu_impl.hpp" +#include "test_gemm_common.hpp" + +template +class TestGemmAddAddFastgelu : public TestGemmD0D1Common +{ + using ProfileCall = typename TestGemmD0D1Common::ProfileCall; + + public: + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_add_fastgelu_impl< + typename TestGemmD0D1Common::ADataType, + typename TestGemmD0D1Common::BDataType, + typename TestGemmD0D1Common::AccDataType, + typename TestGemmD0D1Common::D0DataType, + typename TestGemmD0D1Common::D1DataType, + typename TestGemmD0D1Common::EDataType, + typename TestGemmD0D1Common::ALayout, + typename TestGemmD0D1Common::BLayout, + typename TestGemmD0D1Common::D0Layout, + typename TestGemmD0D1Common::D1Layout, + typename TestGemmD0D1Common::ELayout>; + } +}; + +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmAddAddFastgelu, KernelTypes); +TYPED_TEST(TestGemmAddAddFastgelu, Test_BF16FP16) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_fastgelu_wmma.cpp b/test/gemm_add/test_gemm_add_fastgelu_wmma.cpp index dfe9b14969..278922412f 100644 --- a/test/gemm_add/test_gemm_add_fastgelu_wmma.cpp +++ b/test/gemm_add/test_gemm_add_fastgelu_wmma.cpp @@ -9,29 +9,21 @@ template class TestGemmAddFastgelu : public TestGemmD0Common { - private: - using ADataType = std::tuple_element_t<0, Tuple>; - using BDataType = std::tuple_element_t<1, Tuple>; - using AccDataType = std::tuple_element_t<2, Tuple>; - using D0DataType = std::tuple_element_t<3, Tuple>; - using EDataType = std::tuple_element_t<4, Tuple>; - using ALayout = std::tuple_element_t<5, Tuple>; - using BLayout = std::tuple_element_t<6, Tuple>; - using D0Layout = std::tuple_element_t<7, Tuple>; - using ELayout = std::tuple_element_t<8, Tuple>; + using ProfileCall = typename TestGemmD0Common::ProfileCall; - constexpr static auto ProfileGemmAddFastgeluImpl = - ck::profiler::profile_gemm_add_fastgelu_impl; - - decltype(ProfileGemmAddFastgeluImpl) GetImpl() override { return ProfileGemmAddFastgeluImpl; } + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_fastgelu_impl< + typename TestGemmD0Common::ADataType, + typename TestGemmD0Common::BDataType, + typename TestGemmD0Common::AccDataType, + typename TestGemmD0Common::D0DataType, + typename TestGemmD0Common::EDataType, + typename TestGemmD0Common::ALayout, + typename TestGemmD0Common::BLayout, + typename TestGemmD0Common::D0Layout, + typename TestGemmD0Common::ELayout>; + } }; using KernelTypes = ::testing::Types, diff --git a/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp index 2c055a8006..79e2349088 100644 --- a/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp +++ b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp @@ -9,6 +9,22 @@ template class TestGemmAddFastgelu : public TestGemmD0Common { + using ProfileCall = typename TestGemmD0Common::ProfileCall; + + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_fastgelu_impl< + typename TestGemmD0Common::ADataType, + typename TestGemmD0Common::BDataType, + typename TestGemmD0Common::AccDataType, + typename TestGemmD0Common::D0DataType, + typename TestGemmD0Common::EDataType, + typename TestGemmD0Common::ALayout, + typename TestGemmD0Common::BLayout, + typename TestGemmD0Common::D0Layout, + typename TestGemmD0Common::ELayout>; + } + private: using ADataType = std::tuple_element_t<0, Tuple>; using BDataType = std::tuple_element_t<1, Tuple>; diff --git a/test/gemm_add/test_gemm_common.hpp b/test/gemm_add/test_gemm_common.hpp index 1cf41d7538..ce0f6a66ea 100644 --- a/test/gemm_add/test_gemm_common.hpp +++ b/test/gemm_add/test_gemm_common.hpp @@ -3,7 +3,6 @@ #include "gtest/gtest.h" #include "ck/ck.hpp" -#include "profiler/profile_gemm_add_impl.hpp" using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -13,6 +12,47 @@ using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +template +class TestGemmCommon : public ::testing::Test +{ + protected: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using EDataType = std::tuple_element_t<3, Tuple>; + using ALayout = std::tuple_element_t<4, Tuple>; + using BLayout = std::tuple_element_t<5, Tuple>; + using ELayout = std::tuple_element_t<6, Tuple>; + + using ProfileCall = bool(*const)(int, int, bool, bool, int, int, int, int, int, int); + + virtual ProfileCall GetImpl() = 0; + + void Run() + { + std::vector> lengths = { + {16, 32, 64}, {2048, 4096, 8192}, {2048, 1024, 16}}; + + bool all_success = true; + + for(auto length : lengths) + { + int M = length[0]; + int N = length[1]; + int K = length[2]; + int StrideA = ck::is_same_v ? K : M; + int StrideB = ck::is_same_v ? N : K; + int StrideE = ck::is_same_v ? N : M; + + all_success = + all_success & + GetImpl()(1, 1, false, false, M, N, K, StrideA, StrideB, StrideE); + } + + EXPECT_TRUE(all_success); + } +}; + template class TestGemmD0Common : public ::testing::Test { @@ -27,17 +67,9 @@ class TestGemmD0Common : public ::testing::Test using D0Layout = std::tuple_element_t<7, Tuple>; using ELayout = std::tuple_element_t<8, Tuple>; - constexpr static auto ProfileGemmAddImpl = ck::profiler::profile_gemm_add_impl; + using ProfileCall = bool(*const)(int, int, bool, bool, int, int, int, int, int, int, int); - virtual decltype(ProfileGemmAddImpl) GetImpl() = 0; + virtual ProfileCall GetImpl() = 0; void Run() { @@ -58,7 +90,54 @@ class TestGemmD0Common : public ::testing::Test all_success = all_success & - GetImpl()(true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideE); + GetImpl()(1, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideE); + } + + EXPECT_TRUE(all_success); + } +}; + +template +class TestGemmD0D1Common : public ::testing::Test +{ + protected: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using D0DataType = std::tuple_element_t<3, Tuple>; + using D1DataType = std::tuple_element_t<4, Tuple>; + using EDataType = std::tuple_element_t<5, Tuple>; + using ALayout = std::tuple_element_t<6, Tuple>; + using BLayout = std::tuple_element_t<7, Tuple>; + using D0Layout = std::tuple_element_t<8, Tuple>; + using D1Layout = std::tuple_element_t<9, Tuple>; + using ELayout = std::tuple_element_t<10, Tuple>; + + using ProfileCall = bool(*const)(int, int, bool, bool, int, int, int, int, int, int, int, int); + + virtual ProfileCall GetImpl() = 0; + + void Run() + { + std::vector> lengths = { + {16, 32, 64}, {2048, 4096, 8192}, {2048, 1024, 16}}; + + bool all_success = true; + + for(auto length : lengths) + { + int M = length[0]; + int N = length[1]; + int K = length[2]; + int StrideA = ck::is_same_v ? K : M; + int StrideB = ck::is_same_v ? N : K; + int StrideD0 = ck::is_same_v ? N : M; + int StrideD1 = ck::is_same_v ? N : M; + int StrideE = ck::is_same_v ? N : M; + + all_success = + all_success & + GetImpl()(1, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE); } EXPECT_TRUE(all_success); diff --git a/test/gemm_add/test_gemm_fastgelu_wmma.cpp b/test/gemm_add/test_gemm_fastgelu_wmma.cpp new file mode 100644 index 0000000000..d8dd218ec6 --- /dev/null +++ b/test/gemm_add/test_gemm_fastgelu_wmma.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_fastgelu_impl.hpp" +#include "test_gemm_common.hpp" + +template +class TestGemmFastgelu : public TestGemmCommon +{ + using ProfileCall = typename TestGemmCommon::ProfileCall; + + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_fastgelu_impl::ADataType, + typename TestGemmCommon::BDataType, + typename TestGemmCommon::AccDataType, + typename TestGemmCommon::EDataType, + typename TestGemmCommon::ALayout, + typename TestGemmCommon::BLayout, + typename TestGemmCommon::ELayout>; + } +}; + +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmFastgelu, KernelTypes); +TYPED_TEST(TestGemmFastgelu, Test_BF16FP16) { this->Run(); }