From ac7ecfe7aa4b889bc30060df72e4d7cae92c495f Mon Sep 17 00:00:00 2001 From: Mateusz Ozga <110818320+mozga-amd@users.noreply.github.com> Date: Fri, 16 May 2025 10:18:47 +0200 Subject: [PATCH] Disable conv for Filter1x1Stride1Pad0 when K or C is even (#2186) [ROCm/composable_kernel commit: fa3c6811d8e81096f52779bf0877777bf405d241] --- include/ck/ck.hpp | 3 +++ .../device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 7 +++++++ .../test_grouped_convnd_bwd_weight.cpp | 1 + 3 files changed, 11 insertions(+) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index e38f166c1a..26e4787949 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -222,6 +222,9 @@ // TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" #define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 +// workaround: conv crash when K, C is even +#define CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN 1 + // workaround: compiler crash when compiling recursive lambda #define CK_WORKAROUND_SWDEV_275126 1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index dd5b97096d..869457a99e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -1206,6 +1206,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { +// workaround: disable when K, C is even +#if CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN + if(arg.Conv_C_ % 2 == 0 || arg.Conv_K_ % 2 == 0) + { + return false; + } +#endif // check if it's 1x1, stride=1 pad = 0 conv for(int i = 0; i < NDimSpatial; i++) { diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 21f2cb5ce6..95a0a09414 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -188,6 +188,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D) TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D) { this->conv_params.clear(); + this->conv_params.push_back({2, 2, 64, 4, 4, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( {2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back({2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});