From 2e22c67ce68fc55f210ffa3f9d8fd1365464e240 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 20 Mar 2026 16:45:45 +0100 Subject: [PATCH] [CK][CK Tile] Improve access for merged groups and remove modulo from xor (#5334) ## Motivation [CK][CK Tile] Improve access for merged groups and remove modulo from xor ## Technical Details - add template parameter to xor if modulo is needed. We don't need modulo for merged groups - use access by m for merged groups for a tensor - ## Test Plan test_grouped_convnd_fwd_tile ## Test Result passed locally ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- .../core/algorithm/coordinate_transform.hpp | 21 ++++--- .../grouped_convolution_forward_kernel.hpp | 59 +++++++++++++++---- .../utils/grouped_convolution_utils.hpp | 4 +- .../transform_conv_bwd_weight_to_gemm.hpp | 30 ++++++---- .../utils/transform_conv_fwd_to_gemm.hpp | 24 +++++--- 5 files changed, 96 insertions(+), 42 deletions(-) diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index 30c93b8f00..af43cd3399 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -1298,7 +1298,7 @@ CK_TILE_HOST_DEVICE static void print(const modulo& m) } // 2D XOR, NOTE: "xor" is a keyword -template +template struct xor_t : public base_transform<2, 2> { static constexpr auto type_enum = coord_transform_enum::xor_t; @@ -1330,8 +1330,15 @@ struct xor_t : public base_transform<2, 2> idx_low(number<0>{}) = idx_up[number<0>{}]; - idx_low(number<1>{}) = - idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]); + if constexpr(ApplyModulo) + { + idx_low(number<1>{}) = + idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]); + } + else + { + idx_low(number<1>{}) = idx_up[number<1>{}] ^ (idx_up[number<0>{}]); + } } template @@ -1382,8 +1389,8 @@ struct xor_t : public base_transform<2, 2> } }; -template -CK_TILE_HOST_DEVICE static void print(const xor_t& x) +template +CK_TILE_HOST_DEVICE static void print(const xor_t& x) { printf("xor_t{"); printf("up_lengths_: "); @@ -1737,10 +1744,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus, return modulo{modulus, up_length}; } -template +template CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths) { - return xor_t{low_lengths}; + return xor_t{low_lengths}; } template diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index bbbd248787..1eb0ee2022 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -513,7 +513,9 @@ struct GroupedConvolutionForwardKernel static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK, "Not supported!"); - static_assert(std::is_same_v, "Not supported!"); + static_assert(std::is_same_v || + GroupedConvTraitsType_::NumGroupsToMerge > 1, + "Not supported!"); static_assert(std::is_same_v, "Not supported!"); static_assert(std::is_same_v, "Not supported!"); static_assert(GroupedConvTraitsType_::ExplicitGemm == false || @@ -885,20 +887,51 @@ struct GroupedConvolutionForwardKernel CK_TILE_DEVICE static auto MakeABlockWindow(const InDataType* a_ptr, const ADescType& a_desc, const index_t block_idx_m) { - // Step 1: Create tensor view - const auto& a_tensor_view = make_tensor_view(a_ptr, a_desc); + if constexpr(GroupedConvTraitsType_::NumGroupsToMerge == 1) + { + // Access by K + // Step 1: Create tensor view + const auto& a_tensor_view = make_tensor_view(a_ptr, a_desc); - // Step 2: Create padded view - const auto& a_pad_view = pad_tensor_view( - a_tensor_view, - make_tuple(number{}, number{}), - sequence{}); + // Step 2: Create padded view + const auto& a_pad_view = + pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); - // Step 3: Create tile window - return make_tile_window( - a_pad_view, - make_tuple(number{}, number{}), - {block_idx_m, 0}); + // Step 3: Create tile window + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {block_idx_m, 0}); + } + else + { + // Access by M + const auto a_desc_reversed = transform_tensor_descriptor( + a_desc, + make_tuple(make_pass_through_transform(a_desc.get_length(I0)), + make_pass_through_transform(a_desc.get_length(I1))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + // Step 1: Create tensor view + const auto& a_tensor_view = + make_tensor_view(a_ptr, a_desc_reversed); + + // Step 2: Create padded view + const auto& a_pad_view = + pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + + // Step 3: Create tile window + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, block_idx_m}); + } } template diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 5b00e53af8..2efb435d5b 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -108,7 +108,9 @@ struct GroupedConvTraits using OutLayout = OutLayout_; // Forward Gemm Layouts - using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor; + using AsLayoutFwd = std::conditional_t; using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor; using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor; // Backward Data Gemm Layouts diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp index 0b290a474c..9208be4929 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp @@ -518,10 +518,12 @@ struct TransformConvBwdWeightToGemm NumGroupsToMerge == 32 || NumGroupsToMerge == 64); const auto unmerged_padded_desc = transform_tensor_descriptor( padded_desc, - make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), - make_pass_through_transform(K_), - make_pass_through_transform(X_), - make_pass_through_transform(C_)), + make_tuple( + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_), + make_pass_through_transform(X_), + make_pass_through_transform(C_)), make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}), make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{})); // Merge To M, N @@ -652,10 +654,12 @@ struct TransformConvBwdWeightToGemm NumGroupsToMerge == 32 || NumGroupsToMerge == 64); const auto unmerged_padded_desc = transform_tensor_descriptor( padded_desc, - make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), - make_pass_through_transform(K_), - make_pass_through_transform(Y_ * X_), - make_pass_through_transform(C_)), + make_tuple( + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_), + make_pass_through_transform(Y_ * X_), + make_pass_through_transform(C_)), make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}), make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{})); // Merge To M, N @@ -788,10 +792,12 @@ struct TransformConvBwdWeightToGemm NumGroupsToMerge == 32 || NumGroupsToMerge == 64); const auto unmerged_padded_desc = transform_tensor_descriptor( padded_desc, - make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), - make_pass_through_transform(K_), - make_pass_through_transform(Z_ * Y_ * X_), - make_pass_through_transform(C_)), + make_tuple( + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_), + make_pass_through_transform(Z_ * Y_ * X_), + make_pass_through_transform(C_)), make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}), make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{})); // Merge To M, N diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp index 54fec53d56..46e3033ef1 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp @@ -1363,9 +1363,11 @@ struct TransformConvFwdToGemm NumGroupsToMerge == 32 || NumGroupsToMerge == 64); const auto unmerged_padded_desc = transform_tensor_descriptor( padded_desc, - make_tuple(make_pass_through_transform(NDoHoWo), - make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), - make_pass_through_transform(K_)), + make_tuple( + make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{})); // Merge To M, N @@ -1429,9 +1431,11 @@ struct TransformConvFwdToGemm NumGroupsToMerge == 32 || NumGroupsToMerge == 64); const auto unmerged_padded_desc = transform_tensor_descriptor( padded_desc, - make_tuple(make_pass_through_transform(NDoHoWo), - make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), - make_pass_through_transform(K_)), + make_tuple( + make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{})); // Merge To M, N @@ -1496,9 +1500,11 @@ struct TransformConvFwdToGemm NumGroupsToMerge == 32 || NumGroupsToMerge == 64); const auto unmerged_padded_desc = transform_tensor_descriptor( padded_desc, - make_tuple(make_pass_through_transform(NDoHoWo), - make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), - make_pass_through_transform(K_)), + make_tuple( + make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{})); // Merge To M, N