mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
Post PR183 review fixes. (#224)
* Suppress additional warnings for googltest. * Rename file conv_fwd_util to conv_util. * Update includes and ConvParams member access. * Formatting. * Change conv_fwd_util target to conv_util * Fix compiler errors. * Fix leftovers. Co-authored-by: Adam Osewski <aosewski@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -39,40 +39,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[],
|
||||
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
|
||||
ck::utils::conv::ConvParams params;
|
||||
|
||||
params.num_dim_spatial = num_dim_spatial;
|
||||
params.N = std::stoi(argv[arg_idx++]);
|
||||
params.K = std::stoi(argv[arg_idx++]);
|
||||
params.C = std::stoi(argv[arg_idx++]);
|
||||
params.num_dim_spatial_ = num_dim_spatial;
|
||||
params.N_ = std::stoi(argv[arg_idx++]);
|
||||
params.K_ = std::stoi(argv[arg_idx++]);
|
||||
params.C_ = std::stoi(argv[arg_idx++]);
|
||||
|
||||
params.filter_spatial_lengths.resize(num_dim_spatial);
|
||||
params.filter_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_spatial_lengths.resize(num_dim_spatial);
|
||||
params.input_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_strides.resize(num_dim_spatial);
|
||||
params.conv_filter_strides_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_dilations.resize(num_dim_spatial);
|
||||
params.conv_filter_dilations_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_left_pads.resize(num_dim_spatial);
|
||||
params.input_left_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_right_pads.resize(num_dim_spatial);
|
||||
params.input_right_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return params;
|
||||
@@ -133,16 +133,16 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
params.GetOutputSpatialLengths(),
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads);
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_);
|
||||
break;
|
||||
|
||||
case 2:
|
||||
@@ -158,16 +158,16 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
params.GetOutputSpatialLengths(),
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads);
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_);
|
||||
break;
|
||||
|
||||
case 3:
|
||||
@@ -183,16 +183,16 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
|
||||
init_method,
|
||||
do_log,
|
||||
nrepeat,
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
params.GetOutputSpatialLengths(),
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads);
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_);
|
||||
break;
|
||||
|
||||
default: break;
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <vector>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "conv_fwd_util.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "fill.hpp"
|
||||
#include "profile_convnd_fwd.hpp"
|
||||
|
||||
Reference in New Issue
Block a user