Add UT for all convolution specializations.

This commit is contained in:
Adam Osewski
2022-05-26 15:34:10 +02:00
parent 4fd60bf4c3
commit bc8f54299a

View File

@@ -12,27 +12,56 @@
namespace {
template <typename T>
bool test_conv2d_nhwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs)
class Conv2dFwdNHWCInstances : public ::testing::Test
{
using namespace std::placeholders;
using namespace ck::utils;
public:
bool test_conv2d_nhwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs,
const ck::utils::conv::ConvParams& params)
{
using namespace std::placeholders;
using namespace ck::utils;
conv::ConvParams params;
params.num_dim_spatial_ = 2;
params.filter_spatial_lengths_ = std::vector<ck::index_t>{3, 3};
params.input_spatial_lengths_ = std::vector<ck::index_t>{71, 71};
params.conv_filter_strides_ = std::vector<ck::index_t>{2, 2};
params.conv_filter_dilations_ = std::vector<ck::index_t>{1, 1};
params.input_left_pads_ = std::vector<ck::index_t>{1, 1};
params.input_right_pads_ = std::vector<ck::index_t>{1, 1};
conv::ConvFwdOpInstance<T, T, T> conv_instance(params);
auto reference_conv_fwd_fun =
std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3);
OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
return run_engine.Test(conv_ptrs);
}
conv::ConvFwdOpInstance<T, T, T> conv_instance(params);
bool test_default()
{
return test_conv2d_nhwc_instances(
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(), params_default_);
}
auto reference_conv_fwd_fun =
std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3);
OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
return run_engine.Test(conv_ptrs);
}
bool test_filter1x1_stride1_pad0()
{
return test_conv2d_nhwc_instances(
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
params_filter1x1_stride1_pad0_);
}
bool test_filter1x1_pad0()
{
return test_conv2d_nhwc_instances(
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
params_filter1x1_pad0_);
}
bool test_oddC()
{
return test_conv2d_nhwc_instances(
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(), params_oddC_);
}
static inline ck::utils::conv::ConvParams params_default_;
static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{
2, 4, 256, 128, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{
2, 4, 256, 128, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
static inline ck::utils::conv::ConvParams params_oddC_{
2, 4, 256, 125, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
};
} // anonymous namespace
@@ -60,32 +89,92 @@ TEST(Conv2DFwdNHWC, TestConv2D)
EXPECT_TRUE(run_engine.Test(conv_ptrs));
}
TEST(Conv2DFwdNHWC, Bf16Instances)
using Conv2dFwdNHWCInstancesTypes = ::testing::Types<ck::bhalf_t, /*ck::half_t,*/ float, int8_t>;
TYPED_TEST_SUITE(Conv2dFwdNHWCInstances, Conv2dFwdNHWCInstancesTypes);
TYPED_TEST(Conv2dFwdNHWCInstances, conv_spec_default) { EXPECT_TRUE(this->test_default()); }
TYPED_TEST(Conv2dFwdNHWCInstances, conv_spec_filter1x1_stride1_pad0)
{
EXPECT_TRUE(test_conv2d_nhwc_instances<ck::bhalf_t>(
ck::utils::conv::ConvolutionFwdInstances<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t>::Get<2>()));
EXPECT_TRUE(this->test_filter1x1_stride1_pad0());
}
TEST(Conv2DFwdNHWC, F16Instances)
TYPED_TEST(Conv2dFwdNHWCInstances, conv_spec_filter1x1_pad0)
{
EXPECT_TRUE(test_conv2d_nhwc_instances<ck::half_t>(
ck::utils::conv::ConvolutionFwdInstances<ck::half_t, ck::half_t, ck::half_t>::Get<2>()));
EXPECT_TRUE(this->test_filter1x1_pad0());
}
TEST(Conv2DFwdNHWC, BF32Instances)
TYPED_TEST(Conv2dFwdNHWCInstances, conv_spec_oddC) { EXPECT_TRUE(this->test_oddC()); }
// Workaround for linker error:
// ld.lld: error: undefined symbol: _ZTIDF16_
namespace {
class Conv2dFwdNHWCInstancesF16 : public ::testing::Test
{
EXPECT_TRUE(test_conv2d_nhwc_instances<float>(
ck::utils::conv::ConvolutionFwdInstances<float, float, float>::Get<2>()));
using T = ck::half_t;
public:
bool test_conv2d_nhwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs,
const ck::utils::conv::ConvParams& params)
{
using namespace std::placeholders;
using namespace ck::utils;
conv::ConvFwdOpInstance<T, T, T> conv_instance(params);
auto reference_conv_fwd_fun =
std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3);
OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
return run_engine.Test(conv_ptrs);
}
bool test_default()
{
return test_conv2d_nhwc_instances(
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(), params_default_);
}
bool test_filter1x1_stride1_pad0()
{
return test_conv2d_nhwc_instances(
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
params_filter1x1_stride1_pad0_);
}
bool test_filter1x1_pad0()
{
return test_conv2d_nhwc_instances(
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
params_filter1x1_pad0_);
}
bool test_oddC()
{
return test_conv2d_nhwc_instances(
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(), params_oddC_);
}
protected:
static inline ck::utils::conv::ConvParams params_default_;
static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{
2, 4, 256, 128, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{
2, 4, 256, 128, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
static inline ck::utils::conv::ConvParams params_oddC_{
2, 4, 256, 125, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
};
} // namespace
TEST_F(Conv2dFwdNHWCInstancesF16, conv_spec_default) { EXPECT_TRUE(test_default()); }
TEST_F(Conv2dFwdNHWCInstancesF16, conv_spec_filter1x1_stride1_pad0)
{
EXPECT_TRUE(test_filter1x1_stride1_pad0());
}
TEST(Conv2DFwdNHWC, F32Instances)
{
EXPECT_TRUE(test_conv2d_nhwc_instances<float>(
ck::utils::conv::ConvolutionFwdInstances<float, float, float>::Get<2>()));
}
TEST_F(Conv2dFwdNHWCInstancesF16, conv_spec_filter1x1_pad0) { EXPECT_TRUE(test_filter1x1_pad0()); }
TEST(Conv2DFwdNHWC, Int8Instances)
{
EXPECT_TRUE(test_conv2d_nhwc_instances<int8_t>(
ck::utils::conv::ConvolutionFwdInstances<int8_t, int8_t, int8_t>::Get<2>()));
}
TEST_F(Conv2dFwdNHWCInstancesF16, conv_spec_oddC) { EXPECT_TRUE(test_oddC()); }