diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp index 7296e4faaa..18223c78f7 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -11,6 +11,8 @@ namespace ck { namespace tensor_operation { namespace device { +#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1 + template ()) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 934dc7ee8e..987a1e273a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -671,6 +671,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); +#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN; @@ -683,6 +684,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle grid_size); } else +#endif { k_batch_ = split_k; } @@ -939,6 +941,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif if(!ck::is_xdl_wmma_supported()) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index b361409e38..22fc13bae4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -553,6 +553,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, e_g_k_c_xs_strides); +#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN; @@ -565,6 +566,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle grid_size); } else +#endif { k_batch_ = split_k; } @@ -934,6 +936,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif if(!ck::is_xdl_wmma_supported()) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 8bf188be2e..735eebbdf6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -524,6 +524,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); +#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -549,6 +550,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } } else +#endif { k_batch_ = split_k; } @@ -1275,6 +1277,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp index 2a9421fcd1..354d1fc23b 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp @@ -52,7 +52,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test // clang-format on ck::utils::conv::ConvParam conv_param; - std::vector split_ks{-1, 2}; + ck::index_t split_k_ = 2; template bool Run() @@ -96,30 +96,24 @@ class TestGroupedConvndBwdWeight : public ::testing::Test auto conv = GroupedConvBwdWeightDeviceInstance{}; - bool is_supported = true; - - for(const auto split_k : split_ks) - { - auto argument = conv.MakeArgument(nullptr, - nullptr, - nullptr, - input_lengths, - input_strides, - filter_lengths, - weights_strides, - output_lengths, - output_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}, - split_k); - is_supported &= conv.IsSupportedArgument(argument); - } - return is_supported; + auto argument = conv.MakeArgument(nullptr, + nullptr, + nullptr, + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}, + split_k_); + return conv.IsSupportedArgument(argument); } }; @@ -183,3 +177,12 @@ TYPED_TEST(TestGroupedConvndBwdWeightDefault, VectorLoadCheck) is_supported = this->template Run<2>(); EXPECT_FALSE(is_supported); } + +TYPED_TEST(TestGroupedConvndBwdWeightDefault, SingleStageAutoDeduce) +{ + // Supported version but with auto deduce and single stage + this->conv_param = {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + this->split_k_ = -1; + bool is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); +}