Add wei_strides to grouped conv3d wei to keep consistency (#817)

* Add wei_strides to grouped conv3d wei to keep consistency

* Fix strides in client examples

* Unify backward weight api with forward

* Fix for example

* Fixes for examples

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>

[ROCm/composable_kernel commit: 22443f7aae]
This commit is contained in:
Bartłomiej Kocot
2023-08-07 17:23:45 +02:00
committed by GitHub
parent 446a0cb807
commit 50dbce5272
11 changed files with 357 additions and 433 deletions

View File

@@ -72,10 +72,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
// init to 0
wei_device_buf.SetZero();
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_strides{};
std::array<ck::index_t, NDimSpatial + 3> filter_lengths{};
std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
std::array<ck::index_t, NDimSpatial + 3> output_lengths{};
std::array<ck::index_t, NDimSpatial + 3> output_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
@@ -84,10 +85,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths));
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths));
range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths));
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
@@ -100,14 +102,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.G_,
conv_param.N_,
conv_param.K_,
conv_param.C_,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,