Post PR183 review fixes. (#224)

* Suppress additional warnings for googltest.

* Rename file conv_fwd_util to conv_util.

* Update includes and ConvParams member access.

* Formatting.

* Change conv_fwd_util target to conv_util

* Fix compiler errors.

* Fix leftovers.

Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
Adam Osewski
2022-05-10 22:41:29 +02:00
committed by GitHub
parent f03a1738d9
commit 712e464c4e
35 changed files with 843 additions and 840 deletions

View File

@@ -146,19 +146,19 @@ struct ConvParams
const std::vector<ck::index_t>& left_pads,
const std::vector<ck::index_t>& right_pads);
ck::index_t num_dim_spatial;
ck::index_t N;
ck::index_t K;
ck::index_t C;
ck::index_t num_dim_spatial_;
ck::index_t N_;
ck::index_t K_;
ck::index_t C_;
std::vector<ck::index_t> filter_spatial_lengths;
std::vector<ck::index_t> input_spatial_lengths;
std::vector<ck::index_t> filter_spatial_lengths_;
std::vector<ck::index_t> input_spatial_lengths_;
std::vector<ck::index_t> conv_filter_strides;
std::vector<ck::index_t> conv_filter_dilations;
std::vector<ck::index_t> conv_filter_strides_;
std::vector<ck::index_t> conv_filter_dilations_;
std::vector<ck::index_t> input_left_pads;
std::vector<ck::index_t> input_right_pads;
std::vector<ck::index_t> input_left_pads_;
std::vector<ck::index_t> input_right_pads_;
std::vector<ck::index_t> GetOutputSpatialLengths() const;
};
@@ -268,10 +268,10 @@ void run_reference_convolution_forward(const ConvParams& params,
auto ref_argument = ref_conv.MakeArgument(input,
weights,
output,
params.conv_filter_strides,
params.conv_filter_dilations,
params.input_left_pads,
params.input_right_pads,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_,
PassThrough{},
PassThrough{},
PassThrough{});
@@ -437,17 +437,17 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType,
virtual InTensorsTuple GetInputTensors() const override
{
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params_.N),
static_cast<std::size_t>(params_.C)};
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params_.N_),
static_cast<std::size_t>(params_.C_)};
input_dims.insert(std::end(input_dims),
std::begin(params_.input_spatial_lengths),
std::end(params_.input_spatial_lengths));
std::begin(params_.input_spatial_lengths_),
std::end(params_.input_spatial_lengths_));
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params_.K),
static_cast<std::size_t>(params_.C)};
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params_.K_),
static_cast<std::size_t>(params_.C_)};
filter_dims.insert(std::end(filter_dims),
std::begin(params_.filter_spatial_lengths),
std::end(params_.filter_spatial_lengths));
std::begin(params_.filter_spatial_lengths_),
std::end(params_.filter_spatial_lengths_));
auto input = std::make_unique<Tensor<InDataType>>(
get_host_tensor_descriptor(input_dims, InLayout{}));
@@ -465,8 +465,8 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType,
virtual TensorPtr<OutDataType> GetOutputTensor() const override
{
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params_.N),
static_cast<std::size_t>(params_.K)};
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params_.N_),
static_cast<std::size_t>(params_.K_)};
output_dims.insert(std::end(output_dims),
std::begin(output_spatial_lengths_),
std::end(output_spatial_lengths_));
@@ -522,16 +522,16 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType,
static_cast<InDataType*>(in_device_buffers[0]->GetDeviceBuffer()),
static_cast<WeiDataType*>(in_device_buffers[1]->GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buffer->GetDeviceBuffer()),
params_.N,
params_.K,
params_.C,
params_.input_spatial_lengths,
params_.filter_spatial_lengths,
params_.N_,
params_.K_,
params_.C_,
params_.input_spatial_lengths_,
params_.filter_spatial_lengths_,
output_spatial_lengths_,
params_.conv_filter_strides,
params_.conv_filter_dilations,
params_.input_left_pads,
params_.input_right_pads,
params_.conv_filter_strides_,
params_.conv_filter_dilations_,
params_.input_left_pads_,
params_.input_right_pads_,
InElementwiseOp{},
WeiElementwiseOp{},
OutElementwiseOp{});
@@ -539,20 +539,20 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType,
virtual std::size_t GetFlops() const override
{
return get_flops(params_.N,
params_.C,
params_.K,
params_.filter_spatial_lengths,
return get_flops(params_.N_,
params_.C_,
params_.K_,
params_.filter_spatial_lengths_,
output_spatial_lengths_);
}
virtual std::size_t GetBtype() const override
{
return get_btype<InDataType, WeiDataType, OutDataType>(params_.N,
params_.C,
params_.K,
params_.input_spatial_lengths,
params_.filter_spatial_lengths,
return get_btype<InDataType, WeiDataType, OutDataType>(params_.N_,
params_.C_,
params_.K_,
params_.input_spatial_lengths_,
params_.filter_spatial_lengths_,
output_spatial_lengths_);
}