mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
fixed faulty padding API calls (#8)
This commit is contained in:
@@ -20,8 +20,8 @@ template <index_t GridSize,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads,
|
||||
typename InputLeftPads,
|
||||
typename InputRightPads,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmKPerBlock,
|
||||
@@ -98,8 +98,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Y, X>,
|
||||
Sequence<0, 0>,
|
||||
Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>,
|
||||
true>{}),
|
||||
Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
@@ -121,16 +120,14 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_n_k_hop_wop_global_desc =
|
||||
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
Pad<Sequence<Ho, Wo>,
|
||||
Sequence<0, 0>,
|
||||
Sequence<right_pad_ho, right_pad_wo>,
|
||||
true>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
Pad<Sequence<Ho, Wo>, Sequence<0, 0>, Sequence<right_pad_ho, right_pad_wo>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_hop_wop_global_desc,
|
||||
@@ -154,7 +151,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Hi, Wi>, LeftPads, RightPads, true>{}),
|
||||
Pad<Sequence<Hi, Wi>, InputLeftPads, InputRightPads>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user