diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index e38f166c1a..26e4787949 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -222,6 +222,9 @@ // TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" #define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 +// workaround: conv crash when K, C is even +#define CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN 1 + // workaround: compiler crash when compiling recursive lambda #define CK_WORKAROUND_SWDEV_275126 1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index dd5b97096d..869457a99e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -1206,6 +1206,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { +// workaround: disable when K, C is even +#if CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN + if(arg.Conv_C_ % 2 == 0 || arg.Conv_K_ % 2 == 0) + { + return false; + } +#endif // check if it's 1x1, stride=1 pad = 0 conv for(int i = 0; i < NDimSpatial; i++) { diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 21f2cb5ce6..95a0a09414 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -188,6 +188,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D) TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D) { this->conv_params.clear(); + this->conv_params.push_back({2, 2, 64, 4, 4, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( {2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back({2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});