mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
Post PR183 review fixes. (#224)
* Suppress additional warnings for googltest. * Rename file conv_fwd_util to conv_util. * Update includes and ConvParams member access. * Formatting. * Change conv_fwd_util target to conv_util * Fix compiler errors. * Fix leftovers. Co-authored-by: Adam Osewski <aosewski@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -146,19 +146,19 @@ struct ConvParams
|
||||
const std::vector<ck::index_t>& left_pads,
|
||||
const std::vector<ck::index_t>& right_pads);
|
||||
|
||||
ck::index_t num_dim_spatial;
|
||||
ck::index_t N;
|
||||
ck::index_t K;
|
||||
ck::index_t C;
|
||||
ck::index_t num_dim_spatial_;
|
||||
ck::index_t N_;
|
||||
ck::index_t K_;
|
||||
ck::index_t C_;
|
||||
|
||||
std::vector<ck::index_t> filter_spatial_lengths;
|
||||
std::vector<ck::index_t> input_spatial_lengths;
|
||||
std::vector<ck::index_t> filter_spatial_lengths_;
|
||||
std::vector<ck::index_t> input_spatial_lengths_;
|
||||
|
||||
std::vector<ck::index_t> conv_filter_strides;
|
||||
std::vector<ck::index_t> conv_filter_dilations;
|
||||
std::vector<ck::index_t> conv_filter_strides_;
|
||||
std::vector<ck::index_t> conv_filter_dilations_;
|
||||
|
||||
std::vector<ck::index_t> input_left_pads;
|
||||
std::vector<ck::index_t> input_right_pads;
|
||||
std::vector<ck::index_t> input_left_pads_;
|
||||
std::vector<ck::index_t> input_right_pads_;
|
||||
|
||||
std::vector<ck::index_t> GetOutputSpatialLengths() const;
|
||||
};
|
||||
@@ -268,10 +268,10 @@ void run_reference_convolution_forward(const ConvParams& params,
|
||||
auto ref_argument = ref_conv.MakeArgument(input,
|
||||
weights,
|
||||
output,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
@@ -437,17 +437,17 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType,
|
||||
|
||||
virtual InTensorsTuple GetInputTensors() const override
|
||||
{
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params_.N),
|
||||
static_cast<std::size_t>(params_.C)};
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params_.N_),
|
||||
static_cast<std::size_t>(params_.C_)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params_.input_spatial_lengths),
|
||||
std::end(params_.input_spatial_lengths));
|
||||
std::begin(params_.input_spatial_lengths_),
|
||||
std::end(params_.input_spatial_lengths_));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params_.K),
|
||||
static_cast<std::size_t>(params_.C)};
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params_.K_),
|
||||
static_cast<std::size_t>(params_.C_)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params_.filter_spatial_lengths),
|
||||
std::end(params_.filter_spatial_lengths));
|
||||
std::begin(params_.filter_spatial_lengths_),
|
||||
std::end(params_.filter_spatial_lengths_));
|
||||
|
||||
auto input = std::make_unique<Tensor<InDataType>>(
|
||||
get_host_tensor_descriptor(input_dims, InLayout{}));
|
||||
@@ -465,8 +465,8 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType,
|
||||
|
||||
virtual TensorPtr<OutDataType> GetOutputTensor() const override
|
||||
{
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params_.N),
|
||||
static_cast<std::size_t>(params_.K)};
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params_.N_),
|
||||
static_cast<std::size_t>(params_.K_)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths_),
|
||||
std::end(output_spatial_lengths_));
|
||||
@@ -522,16 +522,16 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType,
|
||||
static_cast<InDataType*>(in_device_buffers[0]->GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(in_device_buffers[1]->GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buffer->GetDeviceBuffer()),
|
||||
params_.N,
|
||||
params_.K,
|
||||
params_.C,
|
||||
params_.input_spatial_lengths,
|
||||
params_.filter_spatial_lengths,
|
||||
params_.N_,
|
||||
params_.K_,
|
||||
params_.C_,
|
||||
params_.input_spatial_lengths_,
|
||||
params_.filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
params_.conv_filter_strides,
|
||||
params_.conv_filter_dilations,
|
||||
params_.input_left_pads,
|
||||
params_.input_right_pads,
|
||||
params_.conv_filter_strides_,
|
||||
params_.conv_filter_dilations_,
|
||||
params_.input_left_pads_,
|
||||
params_.input_right_pads_,
|
||||
InElementwiseOp{},
|
||||
WeiElementwiseOp{},
|
||||
OutElementwiseOp{});
|
||||
@@ -539,20 +539,20 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType,
|
||||
|
||||
virtual std::size_t GetFlops() const override
|
||||
{
|
||||
return get_flops(params_.N,
|
||||
params_.C,
|
||||
params_.K,
|
||||
params_.filter_spatial_lengths,
|
||||
return get_flops(params_.N_,
|
||||
params_.C_,
|
||||
params_.K_,
|
||||
params_.filter_spatial_lengths_,
|
||||
output_spatial_lengths_);
|
||||
}
|
||||
|
||||
virtual std::size_t GetBtype() const override
|
||||
{
|
||||
return get_btype<InDataType, WeiDataType, OutDataType>(params_.N,
|
||||
params_.C,
|
||||
params_.K,
|
||||
params_.input_spatial_lengths,
|
||||
params_.filter_spatial_lengths,
|
||||
return get_btype<InDataType, WeiDataType, OutDataType>(params_.N_,
|
||||
params_.C_,
|
||||
params_.K_,
|
||||
params_.input_spatial_lengths_,
|
||||
params_.filter_spatial_lengths_,
|
||||
output_spatial_lengths_);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user