From bc3a91d23fd8f526eec04fc1d91a4badf119b0ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <188998872+vpietila-amd@users.noreply.github.com> Date: Fri, 17 Oct 2025 15:36:39 +0300 Subject: [PATCH] Fixed handling of split-K autodeduce argument for grouped convolution (#3024) * Fix handling of split-K autodeduce argument. * Fix clang formatting. * Test fix. * Fix clang formatting. [ROCm/composable_kernel commit: 7e44b845b5dd4bcc28d55b4b2764e2be6418a35a] --- ...rd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp | 6 +++++ ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 8 ++++++ ...ce_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 6 +++++ ..._grouped_convnd_bwd_data_interface_xdl.cpp | 27 +++++++++++++++++-- 4 files changed, 45 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index ff652ebefb..febb037157 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -689,6 +689,12 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ return false; } + // Split-K autodeduction is not supported + if(arg.k_batch_ < 1) + { + return false; + } + // Gridwise GEMM size return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 47832e2153..4672de3504 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1523,6 +1523,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 return false; } } + else + { + // Split-K autodeduction is not supported. + if(arg.k_batch_ < 1) + { + return false; + } + } const index_t ConvG = arg.b_g_k_c_xs_lengths_[0]; const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index f6ec0908eb..d5d48777a0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -688,6 +688,12 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK static bool IsSupportedArgument(const Argument& arg) { + // Split-K autodeduction is not supported + if(arg.k_batch_ < 1) + { + return false; + } + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) { return false; diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_xdl.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_xdl.cpp index 01f4260c43..7903c17b22 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_xdl.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_xdl.cpp @@ -47,10 +47,11 @@ class TestGroupedConvndBwdData : public ::testing::Test // ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < NDimSpatial, OutLayout, WeiLayout, ck::Tuple<>, InLayout, DataType, DataType, AccDataType, DataType, ck::Tuple<>, DataType, Pass, Pass, Pass, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>; + < NDimSpatial, OutLayout, WeiLayout, ck::Tuple<>, InLayout, DataType, DataType, AccDataType, DataType, ck::Tuple<>, DataType, Pass, Pass, Pass, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on ck::utils::conv::ConvParam conv_param; + ck::index_t split_k{1}; template bool Run() @@ -112,7 +113,8 @@ class TestGroupedConvndBwdData : public ::testing::Test input_right_pads, Pass{}, Pass{}, - Pass{}); + Pass{}, + split_k); return conv.IsSupportedArgument(argument); } }; @@ -176,3 +178,24 @@ TYPED_TEST(TestGroupedConvndBwdDataDefault, VectorLoadCheck) is_supported = this->template Run<2>(); EXPECT_FALSE(is_supported); } + +TYPED_TEST(TestGroupedConvndBwdDataDefault, SplitK) +{ + if(ck::is_xdl_supported()) + { + // SplitK = 1 + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + this->split_k = 1; + bool is_supported = this->template Run<2>(); + EXPECT_TRUE(is_supported); + + // Split-K autodeduce + this->split_k = -1; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + } + else + { + GTEST_SKIP() << "XDL ops not supported on this device"; + } +}