// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_fastgelu_impl.hpp" #include "test_gemm_common.hpp" template class TestGemmFastgelu : public TestGemmCommon { using ProfileCall = typename TestGemmCommon::ProfileCall; ProfileCall GetImpl() override { return ck::profiler::profile_gemm_fastgelu_impl::ADataType, typename TestGemmCommon::BDataType, typename TestGemmCommon::AccDataType, typename TestGemmCommon::EDataType, typename TestGemmCommon::ALayout, typename TestGemmCommon::BLayout, typename TestGemmCommon::ELayout>; } }; using KernelTypes = ::testing::Types, std::tuple, std::tuple, std::tuple>; TYPED_TEST_SUITE(TestGemmFastgelu, KernelTypes); TYPED_TEST(TestGemmFastgelu, Test_BF16FP16) { this->Run(); }