mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Post PR183 review fixes. (#224)
* Suppress additional warnings for googletest. * Rename file conv_fwd_util to conv_util. * Update includes and ConvParams member access. * Formatting. * Change conv_fwd_util target to conv_util. * Fix compiler errors. * Fix leftovers. Co-authored-by: Adam Osewski <aosewski@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
add_example_executable(example_convnd_fwd_xdl convnd_fwd_xdl.cpp)
|
||||
target_link_libraries(example_convnd_fwd_xdl PRIVATE conv_fwd_util)
|
||||
target_link_libraries(example_convnd_fwd_xdl PRIVATE conv_util)
|
||||
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
|
||||
target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_fwd_util)
|
||||
target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_util)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
|
||||
target_link_libraries(example_convnd_fwd_xdl_fp16 PRIVATE conv_fwd_util)
|
||||
target_link_libraries(example_convnd_fwd_xdl_fp16 PRIVATE conv_util)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "conv_fwd_util.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "device.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
@@ -134,40 +134,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, cha
|
||||
ck::utils::conv::ConvParams params;
|
||||
int arg_idx = 5;
|
||||
|
||||
params.num_dim_spatial = num_dim_spatial;
|
||||
params.N = std::stoi(argv[arg_idx++]);
|
||||
params.K = std::stoi(argv[arg_idx++]);
|
||||
params.C = std::stoi(argv[arg_idx++]);
|
||||
params.num_dim_spatial_ = num_dim_spatial;
|
||||
params.N_ = std::stoi(argv[arg_idx++]);
|
||||
params.K_ = std::stoi(argv[arg_idx++]);
|
||||
params.C_ = std::stoi(argv[arg_idx++]);
|
||||
|
||||
params.filter_spatial_lengths.resize(num_dim_spatial);
|
||||
params.filter_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_spatial_lengths.resize(num_dim_spatial);
|
||||
params.input_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_strides.resize(num_dim_spatial);
|
||||
params.conv_filter_strides_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_dilations.resize(num_dim_spatial);
|
||||
params.conv_filter_dilations_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_left_pads.resize(num_dim_spatial);
|
||||
params.input_left_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_right_pads.resize(num_dim_spatial);
|
||||
params.input_right_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return params;
|
||||
@@ -199,21 +199,21 @@ int main(int argc, char* argv[])
|
||||
params = parse_conv_params(num_dim_spatial, argc, argv);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths),
|
||||
std::end(params.input_spatial_lengths));
|
||||
std::begin(params.input_spatial_lengths_),
|
||||
std::end(params.input_spatial_lengths_));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths),
|
||||
std::end(params.filter_spatial_lengths));
|
||||
std::begin(params.filter_spatial_lengths_),
|
||||
std::end(params.filter_spatial_lengths_));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.K)};
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.K_)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
@@ -255,16 +255,16 @@ int main(int argc, char* argv[])
|
||||
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
@@ -279,13 +279,13 @@ int main(int argc, char* argv[])
|
||||
float ave_time = invoker->Run(argument.get(), nrepeat);
|
||||
|
||||
std::size_t flop = get_flops(
|
||||
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
|
||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||
std::size_t num_btype =
|
||||
get_btype<InDataType, WeiDataType, OutDataType>(params.N,
|
||||
params.C,
|
||||
params.K,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
get_btype<InDataType, WeiDataType, OutDataType>(params.N_,
|
||||
params.C_,
|
||||
params.K_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
@@ -301,10 +301,10 @@ int main(int argc, char* argv[])
|
||||
auto ref_argument = ref_conv.MakeArgument(input,
|
||||
weights,
|
||||
host_output,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "conv_fwd_util.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "device.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
@@ -137,40 +137,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, cha
|
||||
ck::utils::conv::ConvParams params;
|
||||
int arg_idx = 5;
|
||||
|
||||
params.num_dim_spatial = num_dim_spatial;
|
||||
params.N = std::stoi(argv[arg_idx++]);
|
||||
params.K = std::stoi(argv[arg_idx++]);
|
||||
params.C = std::stoi(argv[arg_idx++]);
|
||||
params.num_dim_spatial_ = num_dim_spatial;
|
||||
params.N_ = std::stoi(argv[arg_idx++]);
|
||||
params.K_ = std::stoi(argv[arg_idx++]);
|
||||
params.C_ = std::stoi(argv[arg_idx++]);
|
||||
|
||||
params.filter_spatial_lengths.resize(num_dim_spatial);
|
||||
params.filter_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_spatial_lengths.resize(num_dim_spatial);
|
||||
params.input_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_strides.resize(num_dim_spatial);
|
||||
params.conv_filter_strides_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_dilations.resize(num_dim_spatial);
|
||||
params.conv_filter_dilations_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_left_pads.resize(num_dim_spatial);
|
||||
params.input_left_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_right_pads.resize(num_dim_spatial);
|
||||
params.input_right_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return params;
|
||||
@@ -202,21 +202,21 @@ int main(int argc, char* argv[])
|
||||
params = parse_conv_params(num_dim_spatial, argc, argv);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths),
|
||||
std::end(params.input_spatial_lengths));
|
||||
std::begin(params.input_spatial_lengths_),
|
||||
std::end(params.input_spatial_lengths_));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths),
|
||||
std::end(params.filter_spatial_lengths));
|
||||
std::begin(params.filter_spatial_lengths_),
|
||||
std::end(params.filter_spatial_lengths_));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.K)};
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.K_)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
@@ -256,16 +256,16 @@ int main(int argc, char* argv[])
|
||||
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
@@ -280,13 +280,13 @@ int main(int argc, char* argv[])
|
||||
float ave_time = invoker->Run(argument.get(), nrepeat);
|
||||
|
||||
std::size_t flop = get_flops(
|
||||
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
|
||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||
std::size_t num_btype = get_btype<InDataType, WeiDataType, OutDataType>(
|
||||
params.N,
|
||||
params.C,
|
||||
params.K,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.C_,
|
||||
params.K_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
@@ -302,10 +302,10 @@ int main(int argc, char* argv[])
|
||||
auto ref_argument = ref_conv.MakeArgument(input,
|
||||
weights,
|
||||
host_output,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "conv_fwd_util.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "device.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
@@ -139,40 +139,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, cha
|
||||
ck::utils::conv::ConvParams params;
|
||||
int arg_idx = 5;
|
||||
|
||||
params.num_dim_spatial = num_dim_spatial;
|
||||
params.N = std::stoi(argv[arg_idx++]);
|
||||
params.K = std::stoi(argv[arg_idx++]);
|
||||
params.C = std::stoi(argv[arg_idx++]);
|
||||
params.num_dim_spatial_ = num_dim_spatial;
|
||||
params.N_ = std::stoi(argv[arg_idx++]);
|
||||
params.K_ = std::stoi(argv[arg_idx++]);
|
||||
params.C_ = std::stoi(argv[arg_idx++]);
|
||||
|
||||
params.filter_spatial_lengths.resize(num_dim_spatial);
|
||||
params.filter_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_spatial_lengths.resize(num_dim_spatial);
|
||||
params.input_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_strides.resize(num_dim_spatial);
|
||||
params.conv_filter_strides_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_dilations.resize(num_dim_spatial);
|
||||
params.conv_filter_dilations_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_left_pads.resize(num_dim_spatial);
|
||||
params.input_left_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_right_pads.resize(num_dim_spatial);
|
||||
params.input_right_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return params;
|
||||
@@ -204,21 +204,21 @@ int main(int argc, char* argv[])
|
||||
params = parse_conv_params(num_dim_spatial, argc, argv);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths),
|
||||
std::end(params.input_spatial_lengths));
|
||||
std::begin(params.input_spatial_lengths_),
|
||||
std::end(params.input_spatial_lengths_));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths),
|
||||
std::end(params.filter_spatial_lengths));
|
||||
std::begin(params.filter_spatial_lengths_),
|
||||
std::end(params.filter_spatial_lengths_));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.K)};
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.K_)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
@@ -258,16 +258,16 @@ int main(int argc, char* argv[])
|
||||
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
@@ -282,13 +282,13 @@ int main(int argc, char* argv[])
|
||||
float ave_time = invoker->Run(argument.get(), nrepeat);
|
||||
|
||||
std::size_t flop = get_flops(
|
||||
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
|
||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||
std::size_t num_btype = get_btype<InDataType, WeiDataType, OutDataType>(
|
||||
params.N,
|
||||
params.C,
|
||||
params.K,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.C_,
|
||||
params.K_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
@@ -304,10 +304,10 @@ int main(int argc, char* argv[])
|
||||
auto ref_argument = ref_conv.MakeArgument(input,
|
||||
weights,
|
||||
host_output,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
Reference in New Issue
Block a user