diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 065235349c..f8974c24da 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -35,6 +35,7 @@ add_ck_factory_test(test_testing_utils test_testing_utils.cpp) add_ck_factory_test(test_ck_factory_grouped_convolution_forward test_ck_factory_grouped_convolution_forward.cpp) add_ck_factory_test(test_ck_factory_grouped_convolution_forward_bilinear test_ck_factory_grouped_convolution_forward_bilinear.cpp) add_ck_factory_test(test_ck_factory_grouped_convolution_forward_convscale test_ck_factory_grouped_convolution_forward_convscale.cpp) +add_ck_factory_test(test_ck_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp) add_ck_factory_test(test_ck_factory_grouped_convolution_forward_scale test_ck_factory_grouped_convolution_forward_scale.cpp) add_ck_factory_test(test_ck_factory_grouped_convolution_forward_scaleadd_ab test_ck_factory_grouped_convolution_forward_scaleadd_ab.cpp) add_ck_factory_test(test_ck_factory_grouped_convolution_forward_bias_clamp test_ck_factory_grouped_convolution_forward_bias_clamp.cpp) diff --git a/experimental/builder/test/test_ck_factory_grouped_convolution_forward_dynamic_op.cpp b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_dynamic_op.cpp new file mode 100644 index 0000000000..b857ad05eb --- /dev/null +++ b/experimental/builder/test/test_ck_factory_grouped_convolution_forward_dynamic_op.cpp @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "ck/utility/data_type.hpp" +#include "testing_utils.hpp" + +using ck_tile::test::InstanceSet; +using ck_tile::test::InstancesMatch; + +namespace { + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using DsLayout = ck::Tuple; + +using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD; +using ck::tensor_operation::element_wise::DynamicUnaryOp; +using ck::tensor_operation::element_wise::PassThrough; + +template +using DeviceOp = DeviceGroupedConvFwdMultipleABD, + type, // OutDataType + PassThrough, + PassThrough, + DynamicUnaryOp, + type>; + +} // namespace + +template +struct CkFactoryTestBilinearFwd : public testing::Test +{ + static auto get_actual_instances() + { + return InstanceSet::from_factory(); + } + + static auto get_expected_instances() { return InstanceSet(Case::expected); } +}; + +struct DyOp_F32_2 +{ + using DeviceOp = ::DeviceOp<2, float>; + + constexpr static auto expected = { + // clang-format off + "" + // clang-format on + }; +}; + +struct DyOp_F32_3 +{ + using DeviceOp = ::DeviceOp<3, float>; + + constexpr static auto expected = { + // clang-format off + "" + // clang-format on + }; +}; + +struct DyOp_F16_2 +{ + using DeviceOp = ::DeviceOp<2, ck::half_t>; + + constexpr static auto expected = { + // clang-format off + "" + // clang-format on + }; +}; + +struct DyOp_F16_3 +{ + using DeviceOp = ::DeviceOp<3, ck::half_t>; + + constexpr static auto expected = { + // clang-format off + "" + // clang-format on + }; +}; + +struct DyOp_BF16_2 +{ + using DeviceOp = ::DeviceOp<2, ck::bhalf_t>; + + constexpr static auto expected = { + // clang-format off + "" + // clang-format on + }; +}; + +struct DyOp_BF16_3 +{ + using DeviceOp = ::DeviceOp<3, ck::bhalf_t>; + + constexpr static auto expected = { + // clang-format off + "" + // clang-format on + }; +}; + +struct DyOp_INT8_2 +{ + using DeviceOp = ::DeviceOp<2, int8_t>; + + constexpr static auto expected = { + // clang-format off + "" + // clang-format on + }; +}; + +struct DyOp_INT8_3 +{ + using DeviceOp = ::DeviceOp<3, int8_t>; + + constexpr static auto expected = { + // clang-format off + "" + // clang-format on + }; +}; + +using TestTypes = ::testing::Types; + +TYPED_TEST_SUITE(CkFactoryTestBilinearFwd, TestTypes); + +TYPED_TEST(CkFactoryTestBilinearFwd, TestInstances) +{ + auto actual = TestFixture::get_actual_instances(); + auto expected = TestFixture::get_expected_instances(); + + EXPECT_THAT(actual, InstancesMatch(expected)); +}