Add padding to 1x1Stride1Pad0 conv specialization (grouped conv bwd weight) (#2610)

* Add padding 1x1Stride1Pad0 conv specialization

* Add gridwise checks for conv cshufflev3

* Merge padding with previous transforms

* Apply transform changes for padding to default specialization as well

---------

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
This commit is contained in:
Enrico Degregori
2025-08-05 15:23:19 +02:00
committed by GitHub
parent cbfecf8d7a
commit 2203b0ddfe
5 changed files with 290 additions and 168 deletions

View File

@@ -331,9 +331,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3<
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
AccDataType,
@@ -1299,13 +1299,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// workaround: disable when K, C is even
#if CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN
if(arg.Conv_C_ % 2 == 0 || arg.Conv_K_ % 2 == 0)
{
return false;
}
#endif
// check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NDimSpatial; i++)
{
@@ -1330,7 +1323,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
// Gridwise GEMM size
return true;
return GridwiseGemm::CheckValidity(gemm_arg);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override