diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 794c6f4e20..09801203ba 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -222,9 +222,6 @@ // TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" #define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 -// workaround: conv crash when K, C is even -#define CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN 1 - // workaround: compiler crash when compiling recursive lambda #define CK_WORKAROUND_SWDEV_275126 1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 1cd1f16245..6e74899706 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -1299,13 +1299,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { -// workaround: disable when K, C is even -#if CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN - if(arg.Conv_C_ % 2 == 0 || arg.Conv_K_ % 2 == 0) - { - return false; - } -#endif // check if it's 1x1, stride=1 pad = 0 conv for(int i = 0; i < NDimSpatial; i++) { diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp index bd3ab10802..efc7f20cdc 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp @@ -192,7 +192,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -210,7 +210,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -218,9 +218,17 @@ struct TransformConvBwdWeightToGemm const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + // Padd + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmm_gemmn_pad_grid_desc); } else { @@ -240,7 +248,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -279,7 +287,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -288,26 +296,6 @@ struct TransformConvBwdWeightToGemm make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); // Padd - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch), - make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmM, PadGemmM), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch), - make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmN, PadGemmN), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto wei_gemmm_gemmn_pad_grid_desc = transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc, make_tuple(make_right_pad_transform(GemmM, PadGemmM), @@ -315,8 +303,8 @@ struct TransformConvBwdWeightToGemm make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmm_gemmn_pad_grid_desc); } } @@ -392,7 +380,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -407,13 +395,21 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + // Padd + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_grid_desc); + wei_gemmm_gemmn_pad_grid_desc); } else { @@ -428,7 +424,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -469,31 +465,11 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // Padd - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch), - make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmM, PadGemmM), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch), - make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmN, PadGemmN), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto wei_gemmm_gemmn_pad_grid_desc = transform_tensor_descriptor(wei_grid_desc, make_tuple(make_right_pad_transform(GemmM, PadGemmM), @@ -501,8 +477,8 @@ struct TransformConvBwdWeightToGemm make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmm_gemmn_pad_grid_desc); } } @@ -585,7 +561,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -600,13 +576,21 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + // Padd + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_grid_desc); + wei_gemmm_gemmn_pad_grid_desc); } else { @@ -621,7 +605,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -671,31 +655,11 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // Padd - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch), - make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmM, PadGemmM), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch), - make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmN, PadGemmN), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto wei_gemmm_gemmn_pad_grid_desc = transform_tensor_descriptor(wei_grid_desc, make_tuple(make_right_pad_transform(GemmM, PadGemmM), @@ -703,8 +667,8 @@ struct TransformConvBwdWeightToGemm make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmm_gemmn_pad_grid_desc); } } // function end diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index b72ddb8243..e410f06190 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -374,7 +374,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -390,13 +390,21 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + // Padd + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_grid_desc); + wei_gemmm_gemmn_pad_grid_desc); } else { @@ -412,7 +420,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -453,29 +461,11 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); // Padd - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), - make_right_pad_transform(GemmM, PadGemmM), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), - make_right_pad_transform(GemmN, PadGemmN), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - const auto wei_gemmm_gemmn_pad_grid_desc = transform_tensor_descriptor(wei_grid_desc, make_tuple(make_right_pad_transform(GemmM, PadGemmM), @@ -483,8 +473,8 @@ struct TransformConvBwdWeightToGemmV2 make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmm_gemmn_pad_grid_desc); } @@ -562,7 +552,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -578,13 +568,21 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + // Padd + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_grid_desc); + wei_gemmm_gemmn_pad_grid_desc); } else { @@ -600,7 +598,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -650,29 +648,11 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); // Padd - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), - make_right_pad_transform(GemmM, PadGemmM), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), - make_right_pad_transform(GemmN, PadGemmN), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - const auto wei_gemmm_gemmn_pad_grid_desc = transform_tensor_descriptor(wei_grid_desc, make_tuple(make_right_pad_transform(GemmM, PadGemmM), @@ -680,8 +660,8 @@ struct TransformConvBwdWeightToGemmV2 make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmm_gemmn_pad_grid_desc); } } @@ -765,7 +745,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -781,13 +761,21 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + // Padd + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_grid_desc); + wei_gemmm_gemmn_pad_grid_desc); } else { @@ -803,7 +791,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -868,29 +856,11 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); // Padd - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), - make_right_pad_transform(GemmM, PadGemmM), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), - make_right_pad_transform(GemmN, PadGemmN), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - const auto wei_gemmm_gemmn_pad_grid_desc = transform_tensor_descriptor(wei_grid_desc, make_tuple(make_right_pad_transform(GemmM, PadGemmM), @@ -898,8 +868,8 @@ struct TransformConvBwdWeightToGemmV2 make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmm_gemmn_pad_grid_desc); } } // function end