diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp
index 15c56f9261..1cff9b5733 100644
--- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp
+++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp
@@ -27,7 +27,8 @@ struct GroupedConvBwdDataKernelArgs
                                         GroupedConvTraitsType_::ConvSpecialization,
                                         GroupedConvTraitsType_::VectorSizeA,
                                         GroupedConvTraitsType_::VectorSizeB,
-                                        GroupedConvTraitsType_::VectorSizeC>;
+                                        GroupedConvTraitsType_::VectorSizeC,
+                                        true>; // Split N enabled
 
     static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
     static constexpr auto I0 = number<0>();
@@ -121,6 +122,11 @@ struct GroupedConvBwdDataKernelArgs
             grid_size_ += grid_size_grp;
 
+            // Get the actual split N from transformer
+            n_per_split = conv_to_gemm_transformer.GetN();
+            original_n  = conv_to_gemm_transformer.GetOriginalN();
+            n_splits    = ck_tile::integer_divide_ceil(original_n, n_per_split);
+
             ++gemm_count;
         }
 
         group_stride_a = args.K_; // A: Out NWGK
@@ -131,6 +137,9 @@ struct GroupedConvBwdDataKernelArgs
                                          std::multiplies<index_t>()); // B: Wei GKXC
         group_stride_c = args.C_;                                     // C: In NWGC
 
+        input_batch_stride  = args.C_ * args.G_ * args.input_spatial_lengths_[0];
+        output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0];
+
         GemmBatch = args.G_;
     }
@@ -237,6 +246,11 @@ struct GroupedConvBwdDataKernelArgs
             grid_size_ += grid_size_grp;
 
+            // Get the actual split N from transformer
+            n_per_split = conv_to_gemm_transformer.GetN();
+            original_n  = conv_to_gemm_transformer.GetOriginalN();
+            n_splits    = ck_tile::integer_divide_ceil(original_n, n_per_split);
+
             ++gemm_count;
         }
     }
@@ -248,6 +262,11 @@ struct GroupedConvBwdDataKernelArgs
                                          std::multiplies<index_t>()); // B: Wei GKXC
         group_stride_c = args.C_;                                     // C: In NWGC
 
+        input_batch_stride =
+            args.C_ * args.G_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
+        output_batch_stride =
+            args.K_ * args.G_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
+
         GemmBatch = args.G_;
     }
@@ -369,6 +388,11 @@ struct GroupedConvBwdDataKernelArgs
             grid_size_ += grid_size_grp;
 
+            // Get the actual split N from transformer
+            n_per_split = conv_to_gemm_transformer.GetN();
+            original_n  = conv_to_gemm_transformer.GetOriginalN();
+            n_splits    = ck_tile::integer_divide_ceil(original_n, n_per_split);
+
             ++gemm_count;
         }
     }
@@ -382,6 +406,11 @@ struct GroupedConvBwdDataKernelArgs
                                          std::multiplies<index_t>()); // B: Wei GKXC
         group_stride_c = args.C_;                                     // C: In NWGC
 
+        input_batch_stride = args.C_ * args.G_ * args.input_spatial_lengths_[0] *
+                             args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2];
+        output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0] *
+                              args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];
+
         GemmBatch = args.G_; // C: In NWGC
     }
@@ -425,6 +454,13 @@ struct GroupedConvBwdDataKernelArgs
     long_index_t group_stride_a;
     long_index_t group_stride_b;
     long_index_t group_stride_c;
+
+    // Split-N support fields - initialize to safe defaults
+    index_t n_splits            = 1; // Number of batch splits (e.g., 2 for 128→64×2)
+    index_t n_per_split         = 1; // Batches per split (N_ from transformer)
+    index_t original_n          = 1; // Original batch size before splitting
+    index_t input_batch_stride  = 0; // Stride to next batch in input tensor
+    index_t output_batch_stride = 0; // Stride to next batch in output tensor
 };
 
 /// @brief The Grouped Convolution Backward Data kernel template.
@@ -527,7 +563,7 @@ struct GroupedConvolutionBackwardDataKernel
     CK_TILE_HOST static auto GridSize(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
     {
         // enable batched grouped gemm
-        return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.k_batch);
+        return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.n_splits * kargs.k_batch);
     }
 
     CK_TILE_HOST static constexpr auto BlockSize()
@@ -943,11 +979,31 @@ struct GroupedConvolutionBackwardDataKernel
         const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
         const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
         const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
+        const auto blockIdZ       = amd_wave_read_first_lane(blockIdx.z);
+
+        // SplitN
+        const index_t split_n_idx = __builtin_amdgcn_readfirstlane(blockIdZ / kargs.k_batch);
+        const index_t split_n_offset =
+            __builtin_amdgcn_readfirstlane(split_n_idx * kargs.n_per_split);
+
+        const long_index_t output_batch_offset =
+            static_cast<long_index_t>(split_n_offset) *
+            static_cast<long_index_t>(kargs.output_batch_stride);
+        const long_index_t input_batch_offset = static_cast<long_index_t>(split_n_offset) *
+                                                static_cast<long_index_t>(kargs.input_batch_stride);
+
+        // SplitK
+        // TODO: Implement SplitK support
+        // const index_t split_k_idx =
+        //     __builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
+
         // options
         // conv_bwd_data = Out * Weight = In
-        const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
+        const OutDataType* a_ptr =
+            static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a + output_batch_offset;
         const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
-        InDataType* c_ptr        = static_cast<InDataType*>(kargs.in_ptr) + group_offset_c;
+        InDataType* c_ptr =
+            static_cast<InDataType*>(kargs.in_ptr) + group_offset_c + input_batch_offset;
 
         // allocate LDS
         __shared__ char smem_ptr_0[GetSmemSize()];
diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp
index 359214d3be..a00ea37979 100644
--- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp
+++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp
@@ -27,7 +27,7 @@ struct TransformConvBwdDataToGemm
     static constexpr auto I3 = number<3>{};
     static constexpr auto I4 = number<4>{};
     static constexpr auto I5 = number<5>{};
-#if 0 // TODO: Enable these functionalities
+
     template <typename ConvDimsType>
     static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
                                                           const ConvDimsType& strides,
@@ -44,25 +44,45 @@ struct TransformConvBwdDataToGemm
     }
 
     template <typename ConvDimsType>
-    static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
-                                     const ConvDimsType& a_g_n_c_wis_strides,
-                                     const ConvDimsType& c_g_n_k_wos_lengths,
-                                     const ConvDimsType& c_g_n_k_wos_strides)
+    static IndexType GetSplitedNSize(const ConvDimsType& c_g_n_k_wos_lengths,
+                                     const ConvDimsType& a_g_n_c_wis_lengths)
     {
+
+        // Calculate strides internally assuming contiguous memory layout
+        ConvDimsType c_g_n_k_wos_strides, a_g_n_c_wis_strides;
+        const index_t num_dims = c_g_n_k_wos_strides.size();
+
+        // Calculate strides for input tensor (innermost to outermost),
+        // Don't include outermost dimension G since it's gemm batch.
+        a_g_n_c_wis_strides[num_dims - 1] = 1;
+        for(index_t i = num_dims - 2; i >= 1; i--)
+        {
+            a_g_n_c_wis_strides[i] = a_g_n_c_wis_strides[i + 1] * a_g_n_c_wis_lengths[i + 1];
+        }
+
+        // Calculate strides for output tensor,
+        // Don't include outermost dimension G since it's gemm batch.
+        c_g_n_k_wos_strides[num_dims - 1] = 1;
+        for(index_t i = num_dims - 2; i >= 1; i--)
+        {
+            c_g_n_k_wos_strides[i] = c_g_n_k_wos_strides[i + 1] * c_g_n_k_wos_lengths[i + 1];
+        }
+
         const long_index_t a_element_space_size =
             calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
         const long_index_t c_element_space_size =
            calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
-        const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), c_element_space_size * sizeof(CDataType));
-        constexpr long_index_t TwoGB = (long_index_t{1} << 31);
+        const long_index_t element_space_size = ck_tile::max(
+            a_element_space_size * sizeof(ADataType), c_element_space_size * sizeof(CDataType));
 
-        const IndexType N = a_g_n_c_wis_lengths[I1];
+        constexpr long_index_t TwoGB = (long_index_t{1} << 31);
+
+        const IndexType N = c_g_n_k_wos_lengths[I1];
         if(element_space_size > TwoGB)
         {
             // Minimum divisor of N to not exceed 2GB
-            const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
+            const auto divisor = ck_tile::integer_divide_ceil(element_space_size, TwoGB);
 
             if(divisor <= static_cast<long_index_t>(N))
             {
@@ -93,9 +113,12 @@ struct TransformConvBwdDataToGemm
             return N;
         }
     }
-#endif
 
     public:
+    // Public getter methods for Split-N support
+    CK_TILE_HOST constexpr IndexType GetN() const { return N_; }
+    CK_TILE_HOST constexpr IndexType GetOriginalN() const { return original_N_; }
+
     CK_TILE_HOST constexpr TransformConvBwdDataToGemm() {}
 
     template <typename TransformConvBwdDataToGemmBase>
@@ -103,6 +126,7 @@ struct TransformConvBwdDataToGemm
     TransformConvBwdDataToGemm(const TransformConvBwdDataToGemmBase& transform_conv_to_gemm_base)
         : G_{static_cast<IndexType>(transform_conv_to_gemm_base.G_)},
          N_{static_cast<IndexType>(transform_conv_to_gemm_base.N_)},
+         original_N_{static_cast<IndexType>(transform_conv_to_gemm_base.original_N_)},
          Di_{static_cast<IndexType>(transform_conv_to_gemm_base.Di_)},
          Hi_{static_cast<IndexType>(transform_conv_to_gemm_base.Hi_)},
          Wi_{static_cast<IndexType>(transform_conv_to_gemm_base.Wi_)},
@@ -170,17 +194,18 @@ struct TransformConvBwdDataToGemm
          IdxYTilde_{I1},
          IdxXTilde_{tildes[I0]}
    {
-#if 0 // TODO: Enable these functionalities
+
+        // Store original N
+        original_N_ = a_g_n_c_wis_lengths[I1];
+
        if constexpr(SplitN)
        {
-            N_ = GetSplitedNSize(
-                a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
+            N_ = GetSplitedNSize(c_g_n_k_wos_lengths, a_g_n_c_wis_lengths);
        }
        else
        {
-            N_ = c_g_n_k_wos_lengths[I1];
+            N_ = a_g_n_c_wis_lengths[I1];
        }
-#endif
 
        GcdStrideDilationW_ = gcd(ConvStrideW_, ConvDilationW_);
        XTilde_ = ConvStrideW_ / GcdStrideDilationW_;
@@ -229,17 +254,19 @@ struct TransformConvBwdDataToGemm
          IdxYTilde_{tildes[I0]},
          IdxXTilde_{tildes[I1]}
    {
-#if 0 // TODO: Enable these functionalities
+
+        // Store original N
+        original_N_ = a_g_n_c_wis_lengths[I1];
+
        if constexpr(SplitN)
        {
-            N_ = GetSplitedNSize(
-                a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
+            N_ = GetSplitedNSize(c_g_n_k_wos_lengths, a_g_n_c_wis_lengths);
        }
        else
        {
-            N_ = c_g_n_k_wos_lengths[I1];
+            N_ = a_g_n_c_wis_lengths[I1];
        }
-#endif
+
        GcdStrideDilationW_ = gcd(ConvStrideW_, ConvDilationW_);
        GcdStrideDilationH_ = gcd(ConvStrideH_, ConvDilationH_);
        XTilde_ = ConvStrideW_ / GcdStrideDilationW_;
@@ -291,17 +318,19 @@ struct TransformConvBwdDataToGemm
          IdxYTilde_{tildes[I1]},
          IdxXTilde_{tildes[I2]}
    {
-#if 0 // TODO: Enable these functionalities
+
+        // Store original N
+        original_N_ = a_g_n_c_wis_lengths[I1];
+
        if constexpr(SplitN)
        {
-            N_ = GetSplitedNSize(
-                a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
+            N_ = GetSplitedNSize(c_g_n_k_wos_lengths, a_g_n_c_wis_lengths);
        }
        else
        {
-            N_ = c_g_n_k_wos_lengths[I1];
+            N_ = a_g_n_c_wis_lengths[I1];
        }
-#endif
+
        GcdStrideDilationW_ = gcd(ConvStrideW_, ConvDilationW_);
        GcdStrideDilationH_ = gcd(ConvStrideH_, ConvDilationH_);
        GcdStrideDilationD_ = gcd(ConvStrideD_, ConvDilationD_);
@@ -1068,7 +1097,7 @@ struct TransformConvBwdDataToGemm
                                             in_gemmmraw_gemmnraw_grid_desc);
     }
 
-    IndexType G_, N_;
+    IndexType G_, N_, original_N_;
     IndexType Di_, Hi_, Wi_;
     IndexType Do_, Ho_, Wo_;
     IndexType Z_, Y_, X_;
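
Note (not part of the patch): the Split-N mechanics introduced above can be summarized by the host-side sketch below. It mirrors the arithmetic the kernel performs at the top of operator(): the grid's z dimension enumerates n_splits * k_batch blocks, dividing by k_batch recovers the N-split index, and the per-split pointer offsets are the first batch owned by the split multiplied by the precomputed per-batch strides. The helper names (SplitNCoords, decompose_block_z) and the concrete numbers are illustrative only.

// split_n_sketch.cpp -- illustrative only, not part of the kernel sources.
#include <cassert>
#include <cstdint>

using long_index_t = std::int64_t;
using index_t      = std::int32_t;

// Hypothetical helper mirroring the arithmetic at the top of operator().
struct SplitNCoords
{
    index_t split_n_idx;              // which chunk of the batch this block handles
    index_t split_k_idx;              // reserved; Split-K is still a TODO in the patch
    long_index_t input_batch_offset;  // element offset added to the In (C) pointer
    long_index_t output_batch_offset; // element offset added to the Out (A) pointer
};

SplitNCoords decompose_block_z(index_t block_id_z,
                               index_t k_batch,
                               index_t n_per_split,
                               index_t input_batch_stride,
                               index_t output_batch_stride)
{
    SplitNCoords c{};
    c.split_n_idx = block_id_z / k_batch;
    c.split_k_idx = block_id_z - c.split_n_idx * k_batch;

    // First batch index owned by this split, scaled by the per-batch strides
    // (C * G * prod(input spatial) and K * G * prod(output spatial) in the patch).
    const index_t split_n_offset = c.split_n_idx * n_per_split;
    c.input_batch_offset         = static_cast<long_index_t>(split_n_offset) *
                           static_cast<long_index_t>(input_batch_stride);
    c.output_batch_offset = static_cast<long_index_t>(split_n_offset) *
                            static_cast<long_index_t>(output_batch_stride);
    return c;
}

int main()
{
    // Example: N = 128 split into two chunks of 64, no Split-K (k_batch = 1),
    // with made-up per-batch strides of 512 (input) and 1024 (output) elements.
    const SplitNCoords c = decompose_block_z(/*block_id_z=*/1,
                                             /*k_batch=*/1,
                                             /*n_per_split=*/64,
                                             /*input_batch_stride=*/512,
                                             /*output_batch_stride=*/1024);
    assert(c.split_n_idx == 1 && c.split_k_idx == 0);
    assert(c.input_batch_offset == 64 * 512 && c.output_batch_offset == 64 * 1024);
    return 0;
}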
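
Also illustrative only: the sizing rule behind GetSplitedNSize. The hunks show that the transformer compares the larger of the input/output element-space footprints against a 2 GB limit and derives a minimum divisor of N; the search that picks the concrete per-split batch size lies between the two hunks and is not visible here, so the sketch below substitutes a simple "largest divisor of N that fits under 2 GB" loop under that assumption.

// splitn_sizing_sketch.cpp -- illustrative only; the real selection logic lives
// in GetSplitedNSize and is only partially visible in the hunks above.
#include <cstdint>

using long_index_t = std::int64_t;
using index_t      = std::int32_t;

// Assumed behaviour: given the larger of the input/output footprints in bytes
// for the full batch N, return a per-split batch size whose footprint stays
// below the 2 GB addressing limit, preferring the largest divisor of N.
index_t choose_n_per_split(long_index_t bytes_for_full_n, index_t N)
{
    constexpr long_index_t TwoGB = long_index_t{1} << 31;
    if(bytes_for_full_n <= TwoGB)
    {
        return N; // no split needed; n_splits stays 1
    }

    const long_index_t bytes_per_batch = bytes_for_full_n / N;
    for(index_t n = N / 2; n >= 1; --n)
    {
        if(N % n == 0 && bytes_per_batch * n <= TwoGB)
        {
            return n; // e.g. N = 128 with a ~3 GB footprint -> n_per_split = 64
        }
    }
    return 1; // worst case: one image per split
}

Whatever value the transformer settles on, the kernel arguments then recover n_splits = integer_divide_ceil(original_n, n_per_split), which is exactly the factor the GridSize change multiplies into the grid's z dimension.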