Merge commit '2203b0ddfe06f4f9f5126e54e78697dfb16118d4' into develop

This commit is contained in:
assistant-librarian[bot]
2025-08-05 13:24:51 +00:00
parent 54502bec81
commit e2418402f8
5 changed files with 290 additions and 168 deletions

View File

@@ -331,9 +331,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3<
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
AccDataType,
@@ -1299,13 +1299,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// workaround: disable when K, C is even
#if CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN
if(arg.Conv_C_ % 2 == 0 || arg.Conv_K_ % 2 == 0)
{
return false;
}
#endif
// check that this is a 1x1 convolution with stride = 1 and pad = 0
for(int i = 0; i < NDimSpatial; i++)
{
@@ -1330,7 +1323,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
// Gridwise GEMM size
return true;
return GridwiseGemm::CheckValidity(gemm_arg);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override