Support NHWGC conv2d_bwd_weight (#769)

* Support NHWGC conv2d_bwd_weight

* Fix client example

* Fix client example

* Fix comments

* Redesign grouped_conv_bwd_weight instances

* Clang format fix

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>

[ROCm/composable_kernel commit: 1ee99dcaa6]
This commit is contained in:
Bartłomiej Kocot
2023-07-12 15:25:02 +02:00
committed by GitHub
parent 5b106bca9b
commit dd6d0dd3b9
33 changed files with 1547 additions and 1035 deletions

View File

@@ -3,7 +3,7 @@
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
using InDataType = BF16;
// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
@@ -17,8 +17,20 @@ using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
NDimSpatial, // NDimSpatial
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWC,
ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::GNDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GKXC,
ck::tensor_layout::convolution::GKYXC,
ck::tensor_layout::convolution::GKZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWK,
ck::tensor_layout::convolution::GNHWK,
ck::tensor_layout::convolution::GNDHWK>>,
InDataType, // InDataType
WeiDataType, // WeiDataType
OutDataType, // OutDataType

View File

@@ -3,7 +3,7 @@
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
using InDataType = F16;
using WeiDataType = F16;
@@ -16,8 +16,20 @@ using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
NDimSpatial, // NDimSpatial
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWC,
ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::GNDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GKXC,
ck::tensor_layout::convolution::GKYXC,
ck::tensor_layout::convolution::GKZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWK,
ck::tensor_layout::convolution::GNHWK,
ck::tensor_layout::convolution::GNDHWK>>,
InDataType, // InDataType
WeiDataType, // WeiDataType
OutDataType, // OutDataType

View File

@@ -75,6 +75,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_strides{};
std::array<ck::index_t, NDimSpatial + 3> output_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
@@ -85,6 +87,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
@@ -103,6 +107,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,