Add padding to 1x1Stride1Pad0 conv specialization (grouped conv bwd weight) (#2675)

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

[ROCm/composable_kernel commit: a6f4029276]
This commit is contained in:
Enrico Degregori
2025-08-14 00:21:09 +02:00
committed by GitHub
parent 68eecd0d5e
commit ec1c249b63
4 changed files with 90 additions and 166 deletions

View File

@@ -222,9 +222,6 @@
// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
// workaround: conv crash when K, C is even
#define CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN 1
// workaround: compiler crash when compiling recursive lambda
#define CK_WORKAROUND_SWDEV_275126 1

View File

@@ -1299,13 +1299,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// workaround: disable when K, C is even
#if CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN
if(arg.Conv_C_ % 2 == 0 || arg.Conv_K_ % 2 == 0)
{
return false;
}
#endif
// check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NDimSpatial; i++)
{

View File

@@ -192,7 +192,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -210,7 +210,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -218,9 +218,17 @@ struct TransformConvBwdWeightToGemm
const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
// Padd
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
wei_gemmm_gemmn_pad_grid_desc);
}
else
{
@@ -240,7 +248,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -279,7 +287,7 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -288,26 +296,6 @@ struct TransformConvBwdWeightToGemm
make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
// Padd
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
@@ -315,8 +303,8 @@ struct TransformConvBwdWeightToGemm
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
}
}
@@ -392,7 +380,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -407,13 +395,21 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// Padd
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_grid_desc);
wei_gemmm_gemmn_pad_grid_desc);
}
else
{
@@ -428,7 +424,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -469,31 +465,11 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// Padd
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
@@ -501,8 +477,8 @@ struct TransformConvBwdWeightToGemm
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
}
}
@@ -585,7 +561,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -600,13 +576,21 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// Padd
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_grid_desc);
wei_gemmm_gemmn_pad_grid_desc);
}
else
{
@@ -621,7 +605,7 @@ struct TransformConvBwdWeightToGemm
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -671,31 +655,11 @@ struct TransformConvBwdWeightToGemm
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// Padd
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch),
make_pass_through_transform(GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
@@ -703,8 +667,8 @@ struct TransformConvBwdWeightToGemm
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
}
} // function end

View File

@@ -374,7 +374,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -390,13 +390,21 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// Padd
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_grid_desc);
wei_gemmm_gemmn_pad_grid_desc);
}
else
{
@@ -412,7 +420,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -453,29 +461,11 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// Padd
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
@@ -483,8 +473,8 @@ struct TransformConvBwdWeightToGemmV2
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
}
@@ -562,7 +552,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -578,13 +568,21 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// Padd
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_grid_desc);
wei_gemmm_gemmn_pad_grid_desc);
}
else
{
@@ -600,7 +598,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -650,29 +648,11 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// Padd
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
@@ -680,8 +660,8 @@ struct TransformConvBwdWeightToGemmV2
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
}
}
@@ -765,7 +745,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -781,13 +761,21 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// Padd
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_grid_desc);
wei_gemmm_gemmn_pad_grid_desc);
}
else
{
@@ -803,7 +791,7 @@ struct TransformConvBwdWeightToGemmV2
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_right_pad_transform(GemmM, PadGemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -868,29 +856,11 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// Padd
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
@@ -898,8 +868,8 @@ struct TransformConvBwdWeightToGemmV2
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
}
} // function end