Unified implementation of 1d/2d/3d conv bwd-data. fp32/fp16/bfp16/int8 (#134)

* start convnd bwd data

* add 3d layout name

* add conv1d reference

* add conv3d reference

* finished example client code

* conv1d kernel finished

* fix input error

* add conv3d

* add 3d layout in conv_utils.hpp

* fix special check

* add convnd lib

* add test for bwd data

* finished test

* add check slice length

* convnd bwd data start

* profiler can be compiled

* fix some bugs

* set input to zero

* modify readme for example

* fix test_convnd_bwd_data bug

* test_convnd_bwd_data parameter desc

* workaround for 1d

* workaround for 2d

* change init value

* workaround for 3d int8

* fix init value bug

* remove workaround

* fix acc data type

* add int32

* change select function to template

* tilda to tilde

* remove int32 instance

* fix commit for device hpp

* fix comments for profiler

* using profile imp to test

* add pass verification

* fix conv2d reference

* fix conflict

* remove double batched_gemm

* fix example conv2d data and test convnd

* format

* change conv2d_bwd_data return value

* remove repeat = 1

* remove conv bwd data

Co-authored-by: ltqin <letaoqin@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
ltqin
2022-03-29 23:52:25 +08:00
committed by GitHub
parent fe6ce55c24
commit 0536f2b312
37 changed files with 4577 additions and 245 deletions

View File

@@ -7,9 +7,9 @@
namespace ck {
// Number of GEMMs = YTilda * XTilda
// Number of GEMMs = YTilde * XTilde
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmN = N * HTildeSlice * WTildeSlice
// GemmK = K * YDotSlice * XDotSlice
template <typename... Wei,
typename... In,
@@ -18,8 +18,8 @@ template <typename... Wei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t IYTildaValue,
index_t IXTildaValue,
index_t IYTildeValue,
index_t IXTildeValue,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
@@ -30,8 +30,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<IYTildaValue>,
Number<IXTildaValue>,
Number<IYTildeValue>,
Number<IXTildeValue>,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
@@ -40,8 +40,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
constexpr auto IYTilda = Number<IYTildaValue>{};
constexpr auto IXTilda = Number<IXTildaValue>{};
constexpr auto IYTilde = Number<IYTildeValue>{};
constexpr auto IXTilde = Number<IXTildeValue>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
@@ -71,55 +71,55 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilda = ConvStrideH / GcdStrideDilationH;
const auto XTilda = ConvStrideW / GcdStrideDilationW;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilda);
const auto XDot = math::integer_divide_ceil(X, XTilda);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
const auto IHTildaSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH);
const auto IWTildaSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IHTildaSliceEnd =
math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildaSliceEnd =
math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto IHTildeSliceEnd =
math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd =
math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin;
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda);
const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda);
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - IXTilde, XTilde);
const auto K1 = GemmK1;
const auto K0 = K / K1;
// weight tensor
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor(
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilda),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilda),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(IYTilda),
make_freeze_transform(IXTilda),
make_freeze_transform(IYTilde),
make_freeze_transform(IXTilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
@@ -163,25 +163,25 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor(
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilda),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilda),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
@@ -198,17 +198,17 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
#if 1
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
@@ -224,24 +224,24 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor(
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilda, HTilda),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilda, WTilda),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(IYTilda),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
make_freeze_transform(IXTilda),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
make_freeze_transform(IYTilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(IXTilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
@@ -257,9 +257,9 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_htildaslice_wtildaslice_c_grid_desc,
in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_pass_through_transform(C),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

View File

@@ -10,8 +10,8 @@ namespace ck {
// A: out
// B: wei
// C: in
// Number of GEMMs = YTilda * XTilda
// GemmM = N * HTildaSlice * WTildaSlice
// Number of GEMMs = YTilde * XTilde
// GemmM = N * HTildeSlice * WTildeSlice
// GemmN = C
// GemmK = K * YDotSlice * XDotSlice
template <typename... Wei,
@@ -21,8 +21,8 @@ template <typename... Wei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
typename IYTilda,
typename IXTilda,
typename IYTilde,
typename IXTilde,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
@@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
IYTilda i_ytilda,
IXTilda i_xtilda,
IYTilde i_ytilde,
IXTilde i_xtilde,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
@@ -72,32 +72,32 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilda = ConvStrideH / GcdStrideDilationH;
const auto XTilda = ConvStrideW / GcdStrideDilationW;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilda);
const auto XDot = math::integer_divide_ceil(X, XTilda);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
const auto IHTildaSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH);
const auto IWTildaSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IHTildaSliceEnd =
math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildaSliceEnd =
math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto IHTildeSliceEnd =
math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd =
math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin;
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const auto K1 = GemmK1;
const auto K0 = K / K1;
@@ -113,25 +113,25 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor(
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilda),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilda),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
@@ -148,41 +148,41 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
#if 1
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
// B: weight tensor
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor(
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilda),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilda),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilda),
make_freeze_transform(i_xtilda),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
@@ -225,24 +225,24 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor(
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilda, HTilda),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilda, WTilda),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ytilda),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
make_freeze_transform(i_xtilda),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
@@ -258,8 +258,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_htildaslice_wtildaslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));