// Unit tests (GoogleTest suite `ConvTensorLayout`) for the helper header
// "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp".
//
// Visible structure: eleven TEST cases inside an anonymous namespace, named
// AssignsLayoutsFor{1D,2D,3D}_<input>_<weight>_<output>. Each case declares a
// local alias `TensorLayouts` from `ConvTensorLayouts` and then makes three
// `EXPECT_TRUE((std::is_same_v...))` checks.
//
// NOTE(review): this copy of the file looks damaged and is not compilable as-is:
//   * The two `#include` directives near the top have no target. Presumably
//     these were the GoogleTest header and <type_traits> (TEST/EXPECT_TRUE and
//     std::is_same_v are used below) -- TODO confirm against the upstream file.
//   * `ConvTensorLayouts` and every `std::is_same_v` appear without template
//     arguments. The test names suggest which layout triple each case targets,
//     but the exact signature/dimension/direction arguments and the compared
//     member types cannot be recovered from this text -- restore them from the
//     upstream file rather than guessing.
//   * The whole file is stored on two physical lines, and the first one starts
//     with `//`, so a compiler would treat that entire line as a comment. The
//     original line breaks need to be restored along with the stripped tokens.
//   * `GetTensorLayout` and the `ckb` alias are not referenced anywhere in the
//     visible text; they may only have been used inside the missing template
//     arguments -- verify before removing.
//   * `using enum` requires C++20.
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include #include // Include the helper file we're testing #include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" namespace { namespace ckb = ::ck_tile::builder; using ::ck_tile::builder::factory::internal::ConvTensorLayouts; using ::ck_tile::builder::factory::internal::GetTensorLayout; using enum ::ck_tile::builder::ConvDirection; TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK) { using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); } TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW) { using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); } TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK) { using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); } TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW) { using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); } TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW) { using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); } TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK) { using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); } TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK) { using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); } TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW) { using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); 
EXPECT_TRUE((std::is_same_v)); } TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW) { using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); } TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK) { using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); } TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK) { using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); } } // namespace