From bc8f54299a3827c624866c48076ca4cb6759bd14 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 26 May 2022 15:34:10 +0200 Subject: [PATCH] Add UT for all convolution specializations. --- test/convnd_fwd/conv2d_fwd.cpp | 161 +++++++++++++++++++++++++-------- 1 file changed, 125 insertions(+), 36 deletions(-) diff --git a/test/convnd_fwd/conv2d_fwd.cpp b/test/convnd_fwd/conv2d_fwd.cpp index 05e46147be..7d47289675 100644 --- a/test/convnd_fwd/conv2d_fwd.cpp +++ b/test/convnd_fwd/conv2d_fwd.cpp @@ -12,27 +12,56 @@ namespace { template -bool test_conv2d_nhwc_instances(const std::vector& conv_ptrs) +class Conv2dFwdNHWCInstances : public ::testing::Test { - using namespace std::placeholders; - using namespace ck::utils; + public: + bool test_conv2d_nhwc_instances(const std::vector& conv_ptrs, + const ck::utils::conv::ConvParams& params) + { + using namespace std::placeholders; + using namespace ck::utils; - conv::ConvParams params; - params.num_dim_spatial_ = 2; - params.filter_spatial_lengths_ = std::vector{3, 3}; - params.input_spatial_lengths_ = std::vector{71, 71}; - params.conv_filter_strides_ = std::vector{2, 2}; - params.conv_filter_dilations_ = std::vector{1, 1}; - params.input_left_pads_ = std::vector{1, 1}; - params.input_right_pads_ = std::vector{1, 1}; + conv::ConvFwdOpInstance conv_instance(params); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + return run_engine.Test(conv_ptrs); + } - conv::ConvFwdOpInstance conv_instance(params); + bool test_default() + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), params_default_); + } - auto reference_conv_fwd_fun = - std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); - OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); - return run_engine.Test(conv_ptrs); -} + bool test_filter1x1_stride1_pad0() + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_filter1x1_stride1_pad0_); + } + + bool test_filter1x1_pad0() + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_filter1x1_pad0_); + } + + bool test_oddC() + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), params_oddC_); + } + + static inline ck::utils::conv::ConvParams params_default_; + static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{ + 2, 4, 256, 128, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{ + 2, 4, 256, 128, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + static inline ck::utils::conv::ConvParams params_oddC_{ + 2, 4, 256, 125, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; +}; } // anonymous namespace @@ -60,32 +89,92 @@ TEST(Conv2DFwdNHWC, TestConv2D) EXPECT_TRUE(run_engine.Test(conv_ptrs)); } -TEST(Conv2DFwdNHWC, Bf16Instances) +using Conv2dFwdNHWCInstancesTypes = ::testing::Types; +TYPED_TEST_SUITE(Conv2dFwdNHWCInstances, Conv2dFwdNHWCInstancesTypes); + +TYPED_TEST(Conv2dFwdNHWCInstances, conv_spec_default) { EXPECT_TRUE(this->test_default()); } + +TYPED_TEST(Conv2dFwdNHWCInstances, conv_spec_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); } -TEST(Conv2DFwdNHWC, F16Instances) +TYPED_TEST(Conv2dFwdNHWCInstances, conv_spec_filter1x1_pad0) { - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); + EXPECT_TRUE(this->test_filter1x1_pad0()); } -TEST(Conv2DFwdNHWC, BF32Instances) +TYPED_TEST(Conv2dFwdNHWCInstances, conv_spec_oddC) { EXPECT_TRUE(this->test_oddC()); } + +// Workaround for linker error: +// ld.lld: error: undefined symbol: _ZTIDF16_ + +namespace { + +class Conv2dFwdNHWCInstancesF16 : public ::testing::Test { - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); + + using T = ck::half_t; + + public: + bool test_conv2d_nhwc_instances(const std::vector& conv_ptrs, + const ck::utils::conv::ConvParams& params) + { + using namespace std::placeholders; + using namespace ck::utils; + + conv::ConvFwdOpInstance conv_instance(params); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + return run_engine.Test(conv_ptrs); + } + + bool test_default() + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), params_default_); + } + + bool test_filter1x1_stride1_pad0() + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_filter1x1_stride1_pad0_); + } + + bool test_filter1x1_pad0() + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_filter1x1_pad0_); + } + + bool test_oddC() + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), params_oddC_); + } + + protected: + static inline ck::utils::conv::ConvParams params_default_; + static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{ + 2, 4, 256, 128, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{ + 2, 4, 256, 128, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + static inline ck::utils::conv::ConvParams params_oddC_{ + 2, 4, 256, 125, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; +}; + +} // namespace + +TEST_F(Conv2dFwdNHWCInstancesF16, conv_spec_default) { EXPECT_TRUE(test_default()); } + +TEST_F(Conv2dFwdNHWCInstancesF16, conv_spec_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(test_filter1x1_stride1_pad0()); } -TEST(Conv2DFwdNHWC, F32Instances) -{ - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); -} +TEST_F(Conv2dFwdNHWCInstancesF16, conv_spec_filter1x1_pad0) { EXPECT_TRUE(test_filter1x1_pad0()); } -TEST(Conv2DFwdNHWC, Int8Instances) -{ - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); -} +TEST_F(Conv2dFwdNHWCInstancesF16, conv_spec_oddC) { EXPECT_TRUE(test_oddC()); }