// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_fastgelu_impl.hpp" #include "test_gemm_common.hpp" template class TestGemmAddFastgelu : public TestGemmD0Common { using ProfileCall = typename TestGemmD0Common::ProfileCall; ProfileCall GetImpl() override { return ck::profiler::profile_gemm_add_fastgelu_impl< typename TestGemmD0Common::ADataType, typename TestGemmD0Common::BDataType, typename TestGemmD0Common::AccDataType, typename TestGemmD0Common::D0DataType, typename TestGemmD0Common::EDataType, typename TestGemmD0Common::ALayout, typename TestGemmD0Common::BLayout, typename TestGemmD0Common::D0Layout, typename TestGemmD0Common::ELayout>; } }; using KernelTypes = ::testing::Types, std::tuple, std::tuple, std::tuple, std::tuple, std::tuple>; TYPED_TEST_SUITE(TestGemmAddFastgelu, KernelTypes); TYPED_TEST(TestGemmAddFastgelu, Test_BF16FP16_FP16FP16_INT8) { this->Run(); }
//
// Purpose: typed GoogleTest suite for the GEMM + add + FastGELU profiler path.
//   - TestGemmAddFastgelu derives from the shared TestGemmD0Common fixture.
//     It overrides GetImpl() to return
//     ck::profiler::profile_gemm_add_fastgelu_impl, instantiated with the
//     fixture's ADataType, BDataType, AccDataType, D0DataType and EDataType,
//     and its ALayout, BLayout, D0Layout and ELayout.
//   - KernelTypes is a ::testing::Types list of std::tuple entries. It is
//     registered with TYPED_TEST_SUITE.
//   - The single TYPED_TEST body calls this->Run(). Run() is not defined in
//     this file; presumably it is inherited from TestGemmD0Common in
//     test_gemm_common.hpp.
//
// NOTE(review): this copy of the file appears corrupted and, as written,
// compiles to nothing.
//   - All newlines are collapsed, so the line above is one `//` comment and
//     the #includes and definitions are inert.
//   - Template argument lists in angle brackets appear to have been stripped.
//     `template class` has no parameter list, and `TestGemmD0Common` is used
//     without arguments. The `std::tuple` entries in KernelTypes have no
//     element types, and the first entry's name also seems to be missing.
//   - Presumably the fixture was `template <typename Tuple>` deriving from
//     `TestGemmD0Common<Tuple>`. TODO confirm.
//   - The kernel-type tuples cannot be recovered from here. Restore them from
//     the upstream source rather than guessing.
//   - The test name Test_BF16FP16_FP16FP16_INT8 hints at the covered type
//     combinations. Verify it still matches the restored KernelTypes list.