[CK Tile] Grouped conv fwd splitn support (#2776)

## What's New
  Add Split-N support to grouped convolution forward: when a tensor would exceed the 2 GB (32-bit index) limit, the batch (N) dimension is automatically split into smaller chunks that are processed independently.

  ## Bug Fix
  Fixed a 32-bit integer overflow in the per-split batch-offset arithmetic that caused crashes once 6 or more splits were needed:
  - Use `long_index_t` (64-bit) when computing batch memory offsets, so `batch_offset * batch_stride` cannot overflow
  - Remove the now-redundant `GemmM` initialization in the constructors (`GemmM` is recomputed from the split N after the Split-N calculation)

  ## How It Works
  - The batch dimension is split automatically whenever the tensor would exceed 2 GB
  - The splits are dispatched in parallel over the grid's z dimension (`blockIdx.z`), which is why Split-N and Split-K are mutually exclusive
  - Each workgroup derives its batch offset from its split index and processes that subset of batches independently; weight pointers are not offset, since weights are shared across batches

  ## Testing
  Verified with tile_example_grouped_conv_fwd:
  - n=3000 (6 splits) ✓
  - n=3500 (7 splits) ✓
  - n=10480 (40 splits) ✓
This commit is contained in:
JH-Leon-KIM-AMD
2025-09-16 16:56:11 +03:00
committed by GitHub
parent 59cb906482
commit 804065a36b
2 changed files with 135 additions and 39 deletions

View File

@@ -23,7 +23,8 @@ struct GroupedConvFwdKernelArgs
using ConvToGemmFwdTransformer =
TransformConvFwdToGemm<GroupedConvTraitsType_::NDimSpatial,
GroupedConvTraitsType_::ConvSpecialization>;
GroupedConvTraitsType_::ConvSpecialization,
true>; // Split N enabled
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
template <
@@ -56,7 +57,7 @@ struct GroupedConvFwdKernelArgs
k_batch = args.k_batch;
GemmM = args.N_ * args.output_spatial_lengths_[0];
// GemmM will be set after Split-N calculation
GemmN = args.K_;
GemmK = args.C_ * args.filter_spatial_lengths_[0];
GemmBatch = args.G_;
@@ -94,6 +95,19 @@ struct GroupedConvFwdKernelArgs
1,
std::multiplies<index_t>());
group_stride_c = args.K_;
// Initialize Split-N support fields for 1D convolution (NWGC layout)
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
// Calculate batch strides for NWGC layout
input_batch_stride = args.C_ * args.input_spatial_lengths_[0];
output_batch_stride = args.K_ * args.output_spatial_lengths_[0];
// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0];
}
template <
@@ -133,7 +147,7 @@ struct GroupedConvFwdKernelArgs
k_batch = args.k_batch;
GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
// Note: GemmM will be set after Split-N calculation
GemmN = args.K_;
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1];
GemmBatch = args.G_;
@@ -171,6 +185,21 @@ struct GroupedConvFwdKernelArgs
1,
std::multiplies<index_t>());
group_stride_c = args.K_;
// Initialize Split-N support fields for 2D convolution (NHWGC layout)
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
// Calculate batch strides for NHWGC layout
input_batch_stride =
args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
output_batch_stride =
args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
}
template <
@@ -217,8 +246,7 @@ struct GroupedConvFwdKernelArgs
k_batch = args.k_batch;
GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] *
args.output_spatial_lengths_[2];
// Note: GemmM will be set after Split-N calculation
GemmN = args.K_;
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] *
args.filter_spatial_lengths_[2];
@@ -257,6 +285,22 @@ struct GroupedConvFwdKernelArgs
1,
std::multiplies<index_t>());
group_stride_c = args.K_;
// Initialize Split-N support fields for 3D convolution (NDHWGC layout)
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
// Calculate batch strides for NDHWGC layout
input_batch_stride = args.C_ * args.input_spatial_lengths_[0] *
args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2];
output_batch_stride = args.K_ * args.output_spatial_lengths_[0] *
args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];
// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] *
args.output_spatial_lengths_[2];
}
using AGridDescMK = remove_cvref_t<
@@ -297,6 +341,13 @@ struct GroupedConvFwdKernelArgs
long_index_t group_stride_a;
long_index_t group_stride_b;
long_index_t group_stride_c;
// Split-N support fields - initialize to safe defaults
index_t n_splits = 1; // Number of batch splits (e.g., 2 for 128→64×2)
index_t n_per_split = 1; // Batches per split (N_ from transformer)
index_t original_n = 1; // Original batch size before splitting
index_t input_batch_stride = 0; // Stride to next batch in input tensor
index_t output_batch_stride = 0; // Stride to next batch in output tensor
};
/// @brief The Grouped Convolution Forward kernel template.
@@ -392,10 +443,10 @@ struct GroupedConvolutionForwardKernel
// clang-format on
}
CK_TILE_HOST static constexpr auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs)
CK_TILE_HOST static auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs)
{
return dim3(
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.n_splits);
}
CK_TILE_HOST static auto BlockSize()
@@ -430,6 +481,17 @@ struct GroupedConvolutionForwardKernel
}
}
// Check Split-K and Split-N conflict (both use blockIdx.z)
if(kargs.k_batch > 1 && kargs.n_splits > 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Cannot use both Split-K and Split-N simultaneously (both use blockIdx.z)!");
}
return false;
}
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
@@ -768,10 +830,26 @@ struct GroupedConvolutionForwardKernel
const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY);
// options
const InDataType* a_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a;
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
OutDataType* c_ptr = static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c;
// Split-N handling: Get which split this workgroup handles
const auto blockIdZ = __builtin_amdgcn_readfirstlane(blockIdx.z);
// Calculate batch offset for this split
const index_t batch_offset = __builtin_amdgcn_readfirstlane(blockIdZ * kargs.n_per_split);
// Calculate memory offsets for this split
const long_index_t input_batch_offset = static_cast<long_index_t>(batch_offset) *
static_cast<long_index_t>(kargs.input_batch_stride);
const long_index_t output_batch_offset =
static_cast<long_index_t>(batch_offset) *
static_cast<long_index_t>(kargs.output_batch_stride);
// Adjust pointers: combine group offset and batch offset
const InDataType* a_ptr =
static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
group_offset_b; // No batch offset for weights!
OutDataType* c_ptr =
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];