// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_quantization_impl.hpp" #include "test_gemm_quantization_util.hpp" using I8 = int8_t; using I32 = int32_t; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; template class TestGemmQuantization : public ck::test::TestGemmQuantizationCommon { protected: using ProfileCall = bool (*const)(int, int, bool, bool, int, int, int, int, int, int, float); ProfileCall GetImpl() override { return &ck::profiler::profile_gemm_quantization_impl< typename ck::test::TestGemmQuantizationCommon::ADataType, typename ck::test::TestGemmQuantizationCommon::BDataType, typename ck::test::TestGemmQuantizationCommon::AccDataType, typename ck::test::TestGemmQuantizationCommon::EDataType, typename ck::test::TestGemmQuantizationCommon::ALayout, typename ck::test::TestGemmQuantizationCommon::BLayout, typename ck::test::TestGemmQuantizationCommon::ELayout>; } }; using KernelTypes = ::testing::Types, std::tuple, std::tuple, std::tuple>; TYPED_TEST_SUITE(TestGemmQuantization, KernelTypes); #include "test_gemm_quantization_ut_cases.inc"