mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Post PR183 review fixes. (#224)
* Suppress additional warnings for googltest. * Rename file conv_fwd_util to conv_util. * Update includes and ConvParams member access. * Formatting. * Change conv_fwd_util target to conv_util * Fix compiler errors. * Fix leftovers. Co-authored-by: Adam Osewski <aosewski@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
add_example_executable(example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp)
|
||||
target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_fwd_util)
|
||||
target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_util)
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include <half.hpp>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "conv_fwd_util.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
@@ -105,40 +105,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
|
||||
ck::utils::conv::ConvParams params;
|
||||
int arg_idx = 5;
|
||||
|
||||
params.num_dim_spatial = num_dim_spatial;
|
||||
params.N = std::stoi(argv[arg_idx++]);
|
||||
params.K = std::stoi(argv[arg_idx++]);
|
||||
params.C = std::stoi(argv[arg_idx++]);
|
||||
params.num_dim_spatial_ = num_dim_spatial;
|
||||
params.N_ = std::stoi(argv[arg_idx++]);
|
||||
params.K_ = std::stoi(argv[arg_idx++]);
|
||||
params.C_ = std::stoi(argv[arg_idx++]);
|
||||
|
||||
params.filter_spatial_lengths.resize(num_dim_spatial);
|
||||
params.filter_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_spatial_lengths.resize(num_dim_spatial);
|
||||
params.input_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_strides.resize(num_dim_spatial);
|
||||
params.conv_filter_strides_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_dilations.resize(num_dim_spatial);
|
||||
params.conv_filter_dilations_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_left_pads.resize(num_dim_spatial);
|
||||
params.input_left_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_right_pads.resize(num_dim_spatial);
|
||||
params.input_right_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return params;
|
||||
@@ -171,7 +171,7 @@ int main(int argc, char* argv[])
|
||||
int num_dim_spatial = 2;
|
||||
|
||||
ck::utils::conv::ConvParams params;
|
||||
params.C = 128;
|
||||
params.C_ = 128;
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
@@ -202,21 +202,21 @@ int main(int argc, char* argv[])
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths),
|
||||
std::end(params.input_spatial_lengths));
|
||||
std::begin(params.input_spatial_lengths_),
|
||||
std::end(params.input_spatial_lengths_));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths),
|
||||
std::end(params.filter_spatial_lengths));
|
||||
std::begin(params.filter_spatial_lengths_),
|
||||
std::end(params.filter_spatial_lengths_));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.K)};
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.K_)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
@@ -263,16 +263,16 @@ int main(int argc, char* argv[])
|
||||
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
@@ -287,13 +287,13 @@ int main(int argc, char* argv[])
|
||||
float ave_time = invoker->Run(argument.get(), nrepeat);
|
||||
|
||||
std::size_t flop = ck::utils::conv::get_flops(
|
||||
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
|
||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||
std::size_t num_btype = ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
|
||||
params.N,
|
||||
params.C,
|
||||
params.K,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.C_,
|
||||
params.K_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
@@ -310,10 +310,10 @@ int main(int argc, char* argv[])
|
||||
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result,
|
||||
wei_k_c_y_x,
|
||||
out_n_k_ho_wo,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
Reference in New Issue
Block a user