From 6980efa6fe2bcdd7eb3f9b581b6e26af89073823 Mon Sep 17 00:00:00 2001 From: JH-Leon-KIM-AMD Date: Tue, 16 Sep 2025 16:56:11 +0300 Subject: [PATCH] [CK Tile] Grouped conv fwd splitn support (#2776) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What's New Add Split-N support for grouped convolution forward to handle tensors >2GB by splitting the batch dimension. ## Bug Fix Fixed 32-bit integer overflow that caused crashes with 6+ splits: - Use `long_index_t` for batch offset calculations - Remove redundant GemmM initialization in constructors ## How It Works - Automatically splits batch dimension when tensor exceeds 2GB - Uses grid.z dimension for parallel processing of splits - Each split processes a subset of batches independently ## Testing Verified with tile_example_grouped_conv_fwd: - n=3000 (6 splits) ✓ - n=3500 (7 splits) ✓ - n=10480 (40 splits) ✓ [ROCm/composable_kernel commit: 804065a36b12abbb708ed65eba4513a5df59a25d] --- .../grouped_convolution_forward_kernel.hpp | 100 ++++++++++++++++-- .../utils/transform_conv_fwd_to_gemm.hpp | 74 ++++++++----- 2 files changed, 135 insertions(+), 39 deletions(-) diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index cf4eca7a2d..6fcef5502e 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -23,7 +23,8 @@ struct GroupedConvFwdKernelArgs using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + GroupedConvTraitsType_::ConvSpecialization, + true>; // Split N enabled static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; template < @@ -56,7 +57,7 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - GemmM = args.N_ * args.output_spatial_lengths_[0]; + // GemmM will be set after Split-N calculation GemmN = args.K_; GemmK = args.C_ * args.filter_spatial_lengths_[0]; GemmBatch = args.G_; @@ -94,6 +95,19 @@ struct GroupedConvFwdKernelArgs 1, std::multiplies()); group_stride_c = args.K_; + + // Initialize Split-N support fields for 1D convolution (NWGC layout) + // Get the actual split N from transformer + n_per_split = conv_to_gemm_transformer.GetN(); + original_n = conv_to_gemm_transformer.GetOriginalN(); + n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split); + + // Calculate batch strides for NWGC layout + input_batch_stride = args.C_ * args.input_spatial_lengths_[0]; + output_batch_stride = args.K_ * args.output_spatial_lengths_[0]; + + // Update GemmM to use split N (not original N) + GemmM = n_per_split * args.output_spatial_lengths_[0]; } template < @@ -133,7 +147,7 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; + // Note: GemmM will be set after Split-N calculation GemmN = args.K_; GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1]; GemmBatch = args.G_; @@ -171,6 +185,21 @@ struct GroupedConvFwdKernelArgs 1, std::multiplies()); group_stride_c = args.K_; + + // Initialize Split-N support fields for 2D convolution (NHWGC layout) + // Get the actual split N from transformer + n_per_split = conv_to_gemm_transformer.GetN(); + original_n = conv_to_gemm_transformer.GetOriginalN(); + n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split); + + // Calculate batch strides for NHWGC layout + input_batch_stride = + args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1]; + output_batch_stride = + args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; + + // Update GemmM to use split N (not original N) + GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; } template < @@ -217,8 +246,7 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] * - args.output_spatial_lengths_[2]; + // Note: GemmM will be set after Split-N calculation GemmN = args.K_; GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] * args.filter_spatial_lengths_[2]; @@ -257,6 +285,22 @@ struct GroupedConvFwdKernelArgs 1, std::multiplies()); group_stride_c = args.K_; + + // Initialize Split-N support fields for 3D convolution (NDHWGC layout) + // Get the actual split N from transformer + n_per_split = conv_to_gemm_transformer.GetN(); + original_n = conv_to_gemm_transformer.GetOriginalN(); + n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split); + + // Calculate batch strides for NDHWGC layout + input_batch_stride = args.C_ * args.input_spatial_lengths_[0] * + args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2]; + output_batch_stride = args.K_ * args.output_spatial_lengths_[0] * + args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2]; + + // Update GemmM to use split N (not original N) + GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] * + args.output_spatial_lengths_[2]; } using AGridDescMK = remove_cvref_t< @@ -297,6 +341,13 @@ struct GroupedConvFwdKernelArgs long_index_t group_stride_a; long_index_t group_stride_b; long_index_t group_stride_c; + + // Split-N support fields - initialize to safe defaults + index_t n_splits = 1; // Number of batch splits (e.g., 2 for 128→64×2) + index_t n_per_split = 1; // Batches per split (N_ from transformer) + index_t original_n = 1; // Original batch size before splitting + index_t input_batch_stride = 0; // Stride to next batch in input tensor + index_t output_batch_stride = 0; // Stride to next batch in output tensor }; /// @brief The Grouped Convolution Forward kernel template. @@ -392,10 +443,10 @@ struct GroupedConvolutionForwardKernel // clang-format on } - CK_TILE_HOST static constexpr auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs) + CK_TILE_HOST static auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs) { return dim3( - TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch); + TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.n_splits); } CK_TILE_HOST static auto BlockSize() @@ -430,6 +481,17 @@ struct GroupedConvolutionForwardKernel } } + // Check Split-K and Split-N conflict (both use blockIdx.z) + if(kargs.k_batch > 1 && kargs.n_splits > 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Cannot use both Split-K and Split-N simultaneously (both use blockIdx.z)!"); + } + return false; + } + const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}]; const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}]; @@ -768,10 +830,26 @@ struct GroupedConvolutionForwardKernel const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY); const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY); - // options - const InDataType* a_ptr = static_cast(kargs.in_ptr) + group_offset_a; - const WeiDataType* b_ptr = static_cast(kargs.wei_ptr) + group_offset_b; - OutDataType* c_ptr = static_cast(kargs.out_ptr) + group_offset_c; + // Split-N handling: Get which split this workgroup handles + const auto blockIdZ = __builtin_amdgcn_readfirstlane(blockIdx.z); + + // Calculate batch offset for this split + const index_t batch_offset = __builtin_amdgcn_readfirstlane(blockIdZ * kargs.n_per_split); + + // Calculate memory offsets for this split + const long_index_t input_batch_offset = static_cast(batch_offset) * + static_cast(kargs.input_batch_stride); + const long_index_t output_batch_offset = + static_cast(batch_offset) * + static_cast(kargs.output_batch_stride); + + // Adjust pointers: combine group offset and batch offset + const InDataType* a_ptr = + static_cast(kargs.in_ptr) + group_offset_a + input_batch_offset; + const WeiDataType* b_ptr = static_cast(kargs.wei_ptr) + + group_offset_b; // No batch offset for weights! + OutDataType* c_ptr = + static_cast(kargs.out_ptr) + group_offset_c + output_batch_offset; // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp index c468ae4398..2663d8a494 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp @@ -24,7 +24,7 @@ struct TransformConvFwdToGemm static constexpr auto I3 = number<3>{}; static constexpr auto I4 = number<4>{}; static constexpr auto I5 = number<5>{}; -#if 0 // TODO: Enable these functionalities + template static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths, const ConvDimsType& strides, @@ -42,24 +42,40 @@ struct TransformConvFwdToGemm template static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths, - const ConvDimsType& a_g_n_c_wis_strides, - const ConvDimsType& c_g_n_k_wos_lengths, - const ConvDimsType& c_g_n_k_wos_strides) + const ConvDimsType& c_g_n_k_wos_lengths) { + // Calculate strides internally assuming contiguous memory layout + ConvDimsType a_g_n_c_wis_strides, c_g_n_k_wos_strides; + const index_t num_dims = a_g_n_c_wis_lengths.size(); + + // Calculate strides for input tensor (innermost to outermost) + a_g_n_c_wis_strides[num_dims - 1] = 1; + for(index_t i = num_dims - 2; i >= 0; i--) + { + a_g_n_c_wis_strides[i] = a_g_n_c_wis_strides[i + 1] * a_g_n_c_wis_lengths[i + 1]; + } + + // Calculate strides for output tensor + c_g_n_k_wos_strides[num_dims - 1] = 1; + for(index_t i = num_dims - 2; i >= 0; i--) + { + c_g_n_k_wos_strides[i] = c_g_n_k_wos_strides[i + 1] * c_g_n_k_wos_lengths[i + 1]; + } + const long_index_t a_element_space_size = calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); const long_index_t c_element_space_size = calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1); - const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), - c_element_space_size * sizeof(CDataType)); - constexpr long_index_t TwoGB = (long_index_t{1} << 31); + const long_index_t element_space_size = ck_tile::max( + a_element_space_size * sizeof(ADataType), c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB const IndexType N = a_g_n_c_wis_lengths[I1]; if(element_space_size > TwoGB) { // Minimum divisor of N to not exceed 2GB - const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); + const auto divisor = ck_tile::integer_divide_ceil(element_space_size, TwoGB); if(divisor <= static_cast(N)) { @@ -70,7 +86,8 @@ struct TransformConvFwdToGemm { if(N % least_divisor == 0) { - return N / least_divisor; + IndexType result = N / least_divisor; + return result; } } // Not found, process one Convolution N per block @@ -90,9 +107,12 @@ struct TransformConvFwdToGemm return N; } } -#endif public: + // Public getter methods for Split-N support + CK_TILE_HOST constexpr IndexType GetN() const { return N_; } + CK_TILE_HOST constexpr IndexType GetOriginalN() const { return original_N_; } + CK_TILE_HOST constexpr TransformConvFwdToGemm() {} template @@ -100,6 +120,7 @@ struct TransformConvFwdToGemm TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base) : G_{static_cast(transform_conv_fwd_to_gemm_base.G_)}, N_{static_cast(transform_conv_fwd_to_gemm_base.N_)}, + original_N_{static_cast(transform_conv_fwd_to_gemm_base.original_N_)}, Di_{static_cast(transform_conv_fwd_to_gemm_base.Di_)}, Hi_{static_cast(transform_conv_fwd_to_gemm_base.Hi_)}, Wi_{static_cast(transform_conv_fwd_to_gemm_base.Wi_)}, @@ -168,18 +189,14 @@ struct TransformConvFwdToGemm std::is_same_v>); static_assert(std::is_same_v> || std::is_same_v>); -#if 0 // TODO: Enable these functionalities if constexpr(SplitN) { - N_ = GetSplitedNSize( - a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths); } else { N_ = c_g_n_k_wos_lengths[I1]; } -#endif - N_ = c_g_n_k_wos_lengths[I1]; } template >); static_assert(std::is_same_v> || std::is_same_v>); -#if 0 // TODO: Enable these functionalities + + // Store original N + original_N_ = c_g_n_k_wos_lengths[I1]; + if constexpr(SplitN) { - N_ = GetSplitedNSize( - a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths); } else { - N_ = c_g_n_k_wos_lengths[I1]; + N_ = c_g_n_k_wos_lengths[I1]; + original_N_ = N_; } -#endif - N_ = c_g_n_k_wos_lengths[I1]; } template >); static_assert(std::is_same_v> || std::is_same_v>); -#if 0 // TODO: Enable these functionalities + + // Store original N before potential splitting + original_N_ = c_g_n_k_wos_lengths[I1]; + if constexpr(SplitN) { - N_ = GetSplitedNSize( - a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths); } else { - N_ = c_g_n_k_wos_lengths[I1]; + N_ = original_N_; } -#endif - N_ = c_g_n_k_wos_lengths[I1]; } #if 0 // TODO: Enable these functionalities @@ -1417,7 +1435,7 @@ struct TransformConvFwdToGemm } } - IndexType G_, N_; + IndexType G_, N_, original_N_; IndexType Di_, Hi_, Wi_; IndexType Do_, Ho_, Wo_; IndexType Z_, Y_, X_;