mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Add UT for all convolution specializations.
This commit is contained in:
@@ -12,27 +12,56 @@
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
bool test_conv2d_nhwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs)
|
||||
class Conv2dFwdNHWCInstances : public ::testing::Test
|
||||
{
|
||||
using namespace std::placeholders;
|
||||
using namespace ck::utils;
|
||||
public:
|
||||
bool test_conv2d_nhwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs,
|
||||
const ck::utils::conv::ConvParams& params)
|
||||
{
|
||||
using namespace std::placeholders;
|
||||
using namespace ck::utils;
|
||||
|
||||
conv::ConvParams params;
|
||||
params.num_dim_spatial_ = 2;
|
||||
params.filter_spatial_lengths_ = std::vector<ck::index_t>{3, 3};
|
||||
params.input_spatial_lengths_ = std::vector<ck::index_t>{71, 71};
|
||||
params.conv_filter_strides_ = std::vector<ck::index_t>{2, 2};
|
||||
params.conv_filter_dilations_ = std::vector<ck::index_t>{1, 1};
|
||||
params.input_left_pads_ = std::vector<ck::index_t>{1, 1};
|
||||
params.input_right_pads_ = std::vector<ck::index_t>{1, 1};
|
||||
conv::ConvFwdOpInstance<T, T, T> conv_instance(params);
|
||||
auto reference_conv_fwd_fun =
|
||||
std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3);
|
||||
OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
|
||||
return run_engine.Test(conv_ptrs);
|
||||
}
|
||||
|
||||
conv::ConvFwdOpInstance<T, T, T> conv_instance(params);
|
||||
bool test_default()
|
||||
{
|
||||
return test_conv2d_nhwc_instances(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(), params_default_);
|
||||
}
|
||||
|
||||
auto reference_conv_fwd_fun =
|
||||
std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3);
|
||||
OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
|
||||
return run_engine.Test(conv_ptrs);
|
||||
}
|
||||
bool test_filter1x1_stride1_pad0()
|
||||
{
|
||||
return test_conv2d_nhwc_instances(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
|
||||
params_filter1x1_stride1_pad0_);
|
||||
}
|
||||
|
||||
bool test_filter1x1_pad0()
|
||||
{
|
||||
return test_conv2d_nhwc_instances(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
|
||||
params_filter1x1_pad0_);
|
||||
}
|
||||
|
||||
bool test_oddC()
|
||||
{
|
||||
return test_conv2d_nhwc_instances(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(), params_oddC_);
|
||||
}
|
||||
|
||||
static inline ck::utils::conv::ConvParams params_default_;
|
||||
static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{
|
||||
2, 4, 256, 128, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{
|
||||
2, 4, 256, 128, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
static inline ck::utils::conv::ConvParams params_oddC_{
|
||||
2, 4, 256, 125, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
@@ -60,32 +89,92 @@ TEST(Conv2DFwdNHWC, TestConv2D)
|
||||
EXPECT_TRUE(run_engine.Test(conv_ptrs));
|
||||
}
|
||||
|
||||
TEST(Conv2DFwdNHWC, Bf16Instances)
|
||||
using Conv2dFwdNHWCInstancesTypes = ::testing::Types<ck::bhalf_t, /*ck::half_t,*/ float, int8_t>;
|
||||
TYPED_TEST_SUITE(Conv2dFwdNHWCInstances, Conv2dFwdNHWCInstancesTypes);
|
||||
|
||||
TYPED_TEST(Conv2dFwdNHWCInstances, conv_spec_default) { EXPECT_TRUE(this->test_default()); }
|
||||
|
||||
TYPED_TEST(Conv2dFwdNHWCInstances, conv_spec_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(test_conv2d_nhwc_instances<ck::bhalf_t>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t>::Get<2>()));
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0());
|
||||
}
|
||||
|
||||
TEST(Conv2DFwdNHWC, F16Instances)
|
||||
TYPED_TEST(Conv2dFwdNHWCInstances, conv_spec_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(test_conv2d_nhwc_instances<ck::half_t>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<ck::half_t, ck::half_t, ck::half_t>::Get<2>()));
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0());
|
||||
}
|
||||
|
||||
TEST(Conv2DFwdNHWC, BF32Instances)
|
||||
TYPED_TEST(Conv2dFwdNHWCInstances, conv_spec_oddC) { EXPECT_TRUE(this->test_oddC()); }
|
||||
|
||||
// Workaround for linker error:
|
||||
// ld.lld: error: undefined symbol: _ZTIDF16_
|
||||
|
||||
namespace {
|
||||
|
||||
class Conv2dFwdNHWCInstancesF16 : public ::testing::Test
|
||||
{
|
||||
EXPECT_TRUE(test_conv2d_nhwc_instances<float>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<float, float, float>::Get<2>()));
|
||||
|
||||
using T = ck::half_t;
|
||||
|
||||
public:
|
||||
bool test_conv2d_nhwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs,
|
||||
const ck::utils::conv::ConvParams& params)
|
||||
{
|
||||
using namespace std::placeholders;
|
||||
using namespace ck::utils;
|
||||
|
||||
conv::ConvFwdOpInstance<T, T, T> conv_instance(params);
|
||||
auto reference_conv_fwd_fun =
|
||||
std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3);
|
||||
OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
|
||||
return run_engine.Test(conv_ptrs);
|
||||
}
|
||||
|
||||
bool test_default()
|
||||
{
|
||||
return test_conv2d_nhwc_instances(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(), params_default_);
|
||||
}
|
||||
|
||||
bool test_filter1x1_stride1_pad0()
|
||||
{
|
||||
return test_conv2d_nhwc_instances(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
|
||||
params_filter1x1_stride1_pad0_);
|
||||
}
|
||||
|
||||
bool test_filter1x1_pad0()
|
||||
{
|
||||
return test_conv2d_nhwc_instances(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
|
||||
params_filter1x1_pad0_);
|
||||
}
|
||||
|
||||
bool test_oddC()
|
||||
{
|
||||
return test_conv2d_nhwc_instances(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(), params_oddC_);
|
||||
}
|
||||
|
||||
protected:
|
||||
static inline ck::utils::conv::ConvParams params_default_;
|
||||
static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{
|
||||
2, 4, 256, 128, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{
|
||||
2, 4, 256, 128, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
static inline ck::utils::conv::ConvParams params_oddC_{
|
||||
2, 4, 256, 125, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST_F(Conv2dFwdNHWCInstancesF16, conv_spec_default) { EXPECT_TRUE(test_default()); }
|
||||
|
||||
TEST_F(Conv2dFwdNHWCInstancesF16, conv_spec_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(test_filter1x1_stride1_pad0());
|
||||
}
|
||||
|
||||
TEST(Conv2DFwdNHWC, F32Instances)
|
||||
{
|
||||
EXPECT_TRUE(test_conv2d_nhwc_instances<float>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<float, float, float>::Get<2>()));
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstancesF16, conv_spec_filter1x1_pad0) { EXPECT_TRUE(test_filter1x1_pad0()); }
|
||||
|
||||
TEST(Conv2DFwdNHWC, Int8Instances)
|
||||
{
|
||||
EXPECT_TRUE(test_conv2d_nhwc_instances<int8_t>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<int8_t, int8_t, int8_t>::Get<2>()));
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstancesF16, conv_spec_oddC) { EXPECT_TRUE(test_oddC()); }
|
||||
|
||||
Reference in New Issue
Block a user