mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Support NHWGC conv2d_bwd_weight (#769)
* Support NHWGC conv2d_bwd_weight * Fix client example * Fix client example * Fix comments * Redesign grouped_conv_bwd_weight instances * Clang format fix --------- Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -195,17 +195,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -347,17 +347,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
} // function end
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -515,17 +515,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -784,17 +784,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
@@ -897,18 +899,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
InElementwiseOperation c_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
index_t Conv_G_;
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
const index_t Conv_G_;
|
||||
const index_t Conv_N_;
|
||||
const index_t Conv_K_;
|
||||
const index_t Conv_C_;
|
||||
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations_;
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
index_t k_batch_;
|
||||
};
|
||||
|
||||
@@ -1111,17 +1113,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
@@ -1137,6 +1141,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -1153,17 +1159,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
void* p_wei_grid,
|
||||
const void* p_out_grid,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
@@ -1179,6 +1187,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
@@ -126,6 +126,9 @@ __global__ void
|
||||
|
||||
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
@@ -161,29 +164,19 @@ template <ck::index_t NDimSpatial,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
|
||||
struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
: public DeviceGroupedConvBwdWeight<
|
||||
NDimSpatial,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GNDHWC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::tensor_layout::convolution::GNDHWK>>,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
: public DeviceGroupedConvBwdWeight<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle;
|
||||
using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle;
|
||||
|
||||
using ADataType = OutDataType;
|
||||
using BDataType = InDataType;
|
||||
@@ -222,17 +215,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -282,14 +277,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -372,19 +367,25 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
template <ck::index_t NDim,
|
||||
typename ck::enable_if<NDim == 2 &&
|
||||
is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNHWK>,
|
||||
bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -447,14 +448,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -539,19 +540,202 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDim,
|
||||
typename ck::enable_if<NDim == 2 &&
|
||||
is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NHWGK>,
|
||||
bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
const index_t Y = filter_spatial_lengths[0];
|
||||
const index_t X = filter_spatial_lengths[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t GemmKTotal = N * Ho * Wo;
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X * Y;
|
||||
|
||||
const index_t NStride = input_strides[1];
|
||||
const index_t HiStride = input_strides[3];
|
||||
const index_t WiStride = input_strides[4];
|
||||
const auto CStride = input_strides[2];
|
||||
|
||||
const index_t WoStride = output_strides[4];
|
||||
const auto KStride = Number<1>{};
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// A: output tensor
|
||||
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * Ho * Wo, K), make_tuple(WoStride, KStride));
|
||||
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * Hi * Wi, C), make_tuple(WiStride, CStride));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * Ho * Wo, K), make_tuple(WoStride, KStride));
|
||||
const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmktotal_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -621,14 +805,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -725,31 +909,70 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
|
||||
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
|
||||
const ck::index_t dim = 1;
|
||||
const ck::index_t batch = 1;
|
||||
const std::array<ck::index_t, NDimSpatial> lengths{1};
|
||||
const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1};
|
||||
const std::array<ck::index_t, NDimSpatial> params{1};
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(dim,
|
||||
dim,
|
||||
dim,
|
||||
lengths,
|
||||
lengths,
|
||||
lengths,
|
||||
strides,
|
||||
strides,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
batch);
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
|
||||
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
|
||||
const ck::index_t dim = 1;
|
||||
const ck::index_t batch = 1;
|
||||
const std::array<ck::index_t, NDimSpatial> lengths{1, 1};
|
||||
const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1};
|
||||
const std::array<ck::index_t, NDimSpatial> params{1, 1};
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim,
|
||||
dim,
|
||||
dim,
|
||||
lengths,
|
||||
lengths,
|
||||
lengths,
|
||||
strides,
|
||||
strides,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
batch);
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1,
|
||||
1,
|
||||
1,
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
1);
|
||||
const ck::index_t dim = 1;
|
||||
const ck::index_t batch = 1;
|
||||
const std::array<ck::index_t, NDimSpatial> lengths{1, 1, 1};
|
||||
const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
|
||||
const std::array<ck::index_t, NDimSpatial> params{1, 1, 1};
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim,
|
||||
dim,
|
||||
dim,
|
||||
lengths,
|
||||
lengths,
|
||||
lengths,
|
||||
strides,
|
||||
strides,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
batch);
|
||||
}
|
||||
|
||||
// type convert descs
|
||||
@@ -863,19 +1086,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t M01,
|
||||
const ck::index_t N01,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
@@ -913,6 +1138,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -927,18 +1154,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ =
|
||||
N * K *
|
||||
std::accumulate(begin(output_spatial_lengths),
|
||||
end(output_spatial_lengths),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ =
|
||||
N * C *
|
||||
std::accumulate(begin(input_spatial_lengths),
|
||||
end(input_spatial_lengths),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = output_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = input_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ =
|
||||
K * C *
|
||||
std::accumulate(begin(filter_spatial_lengths),
|
||||
@@ -977,16 +1194,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
WeiElementwiseOperation c_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
index_t Conv_G_;
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads_;
|
||||
index_t k_batch_;
|
||||
const index_t Conv_G_;
|
||||
const index_t Conv_N_;
|
||||
const index_t Conv_K_;
|
||||
const index_t Conv_C_;
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -1091,6 +1308,45 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
if constexpr(!(is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNWK>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!(is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
|
||||
is_same_v<OutLayout,
|
||||
tensor_layout::convolution::
|
||||
GNHWK>)&&!(is_same_v<InLayout,
|
||||
tensor_layout::convolution::NHWGC> &&
|
||||
is_same_v<WeiLayout,
|
||||
tensor_layout::convolution::GKYXC> &&
|
||||
is_same_v<OutLayout,
|
||||
tensor_layout::convolution::NHWGK>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
if constexpr(!(is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
@@ -1134,21 +1390,23 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
ck::index_t split_k)
|
||||
const ck::index_t split_k)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
@@ -1160,6 +1418,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -1178,21 +1438,23 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
void* p_wei_grid,
|
||||
const void* p_out_grid,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
ck::index_t split_k) override
|
||||
const ck::index_t split_k) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<WeiDataType*>(p_wei_grid),
|
||||
@@ -1204,6 +1466,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -1226,7 +1490,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle"
|
||||
str << "DeviceGroupedConvBwdWeight_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
Reference in New Issue
Block a user