// NOTE(review): this file appears to have been collapsed onto a single line.
// Because that line starts with `//`, everything on it -- the #includes, the
// TestGemmAddSilu fixture, KernelTypes, TYPED_TEST_SUITE and TYPED_TEST -- is
// currently one comment, so no test is compiled or registered.
// Template argument lists also appear to have been stripped, for example
// `template class TestGemmAddSilu : public TestGemmD0Common` and
// `::testing::Types, std::tuple, std::tuple, ...`. The original code therefore
// cannot be rebuilt from this copy alone.
// TODO: restore the line breaks and template arguments from version control.
//
// Intended contents, as far as the surviving text shows: a typed GoogleTest
// fixture TestGemmAddSilu deriving from TestGemmD0Common. Its GetImpl()
// override returns ck::profiler::profile_gemm_add_silu_impl instantiated with
// the fixture's ADataType, BDataType, AccDataType, D0DataType, EDataType,
// ALayout, BLayout, D0Layout and ELayout. The suite is instantiated over
// KernelTypes, and its single test body calls this->Run().
// NOTE(review): once the KernelTypes tuples are restored, check the test name
// Test_BF16FP16_BF16FP16_INT8 against the data types they actually list.
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_silu_impl.hpp" #include "test_gemm_common.hpp" template class TestGemmAddSilu : public TestGemmD0Common { using ProfileCall = typename TestGemmD0Common::ProfileCall; ProfileCall GetImpl() override { return ck::profiler::profile_gemm_add_silu_impl< typename TestGemmD0Common::ADataType, typename TestGemmD0Common::BDataType, typename TestGemmD0Common::AccDataType, typename TestGemmD0Common::D0DataType, typename TestGemmD0Common::EDataType, typename TestGemmD0Common::ALayout, typename TestGemmD0Common::BLayout, typename TestGemmD0Common::D0Layout, typename TestGemmD0Common::ELayout>; } }; using KernelTypes = ::testing::Types, std::tuple, std::tuple, std::tuple>; TYPED_TEST_SUITE(TestGemmAddSilu, KernelTypes); TYPED_TEST(TestGemmAddSilu, Test_BF16FP16_BF16FP16_INT8) { this->Run(); }