// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_impl.hpp" #include "test_gemm_common.hpp" template class TestGemmAdd : public TestGemmD0Common { using ProfileCall = typename TestGemmD0Common::ProfileCall; ProfileCall GetImpl() override { return ck::profiler::profile_gemm_add_impl::ADataType, typename TestGemmD0Common::BDataType, typename TestGemmD0Common::AccDataType, typename TestGemmD0Common::D0DataType, typename TestGemmD0Common::EDataType, typename TestGemmD0Common::ALayout, typename TestGemmD0Common::BLayout, typename TestGemmD0Common::D0Layout, typename TestGemmD0Common::ELayout>; } }; using KernelTypes = ::testing::Types, std::tuple, std::tuple, std::tuple>; TYPED_TEST_SUITE(TestGemmAdd, KernelTypes); TYPED_TEST(TestGemmAdd, Test_BF16FP16_FP16FP16_INT8) { this->Run(); }
//
// NOTE(review): This file is broken and must be restored from version control.
//
// 1) The entire file sits on the single line above, and that line starts with
//    `//`. The compiler therefore treats all of it as one comment: no
//    #include, class, KernelTypes alias, or test is actually compiled.
//
// 2) Every template argument list appears to have been stripped:
//    - `template class TestGemmAdd` has no parameter list.
//    - `TestGemmD0Common` is used without arguments.
//    - `profile_gemm_add_impl::ADataType, ... ELayout>` has a closing `>` but
//      no opening `<`.
//    - `::testing::Types, std::tuple, ...` names no element types.
//
//    These cannot be reconstructed from this file alone. In particular, the
//    concrete type combinations in KernelTypes are lost. Restore them from
//    upstream rather than guessing.
//
// What the intact file appears to do (inferred from the surviving tokens;
// confirm against upstream):
// - TestGemmAdd is a GoogleTest typed fixture derived from TestGemmD0Common.
// - Its GetImpl() override returns ck::profiler::profile_gemm_add_impl,
//   instantiated with the fixture's A/B/Acc/D0/E data types and A/B/D0/E
//   layouts.
// - The single TYPED_TEST calls this->Run() once for each type in
//   KernelTypes.
// - Presumably Run() invokes the ProfileCall returned by GetImpl(). TODO:
//   confirm in test_gemm_common.hpp.