mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK Tile] Grouped conv fwd splitn support (#2776)
## What's New Add Split-N support for grouped convolution forward to handle tensors >2GB by splitting the batch dimension. ## Bug Fix Fixed 32-bit integer overflow that caused crashes with 6+ splits: - Use `long_index_t` for batch offset calculations - Remove redundant GemmM initialization in constructors ## How It Works - Automatically splits batch dimension when tensor exceeds 2GB - Uses grid.z dimension for parallel processing of splits - Each split processes a subset of batches independently ## Testing Verified with tile_example_grouped_conv_fwd: - n=3000 (6 splits) ✓ - n=3500 (7 splits) ✓ - n=10480 (40 splits) ✓
This commit is contained in:
@@ -23,7 +23,8 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
using ConvToGemmFwdTransformer =
|
||||
TransformConvFwdToGemm<GroupedConvTraitsType_::NDimSpatial,
|
||||
GroupedConvTraitsType_::ConvSpecialization>;
|
||||
GroupedConvTraitsType_::ConvSpecialization,
|
||||
true>; // Split N enabled
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
template <
|
||||
@@ -56,7 +57,7 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
GemmM = args.N_ * args.output_spatial_lengths_[0];
|
||||
// GemmM will be set after Split-N calculation
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0];
|
||||
GemmBatch = args.G_;
|
||||
@@ -94,6 +95,19 @@ struct GroupedConvFwdKernelArgs
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
group_stride_c = args.K_;
|
||||
|
||||
// Initialize Split-N support fields for 1D convolution (NWGC layout)
|
||||
// Get the actual split N from transformer
|
||||
n_per_split = conv_to_gemm_transformer.GetN();
|
||||
original_n = conv_to_gemm_transformer.GetOriginalN();
|
||||
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
|
||||
|
||||
// Calculate batch strides for NWGC layout
|
||||
input_batch_stride = args.C_ * args.input_spatial_lengths_[0];
|
||||
output_batch_stride = args.K_ * args.output_spatial_lengths_[0];
|
||||
|
||||
// Update GemmM to use split N (not original N)
|
||||
GemmM = n_per_split * args.output_spatial_lengths_[0];
|
||||
}
|
||||
|
||||
template <
|
||||
@@ -133,7 +147,7 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
// Note: GemmM will be set after Split-N calculation
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1];
|
||||
GemmBatch = args.G_;
|
||||
@@ -171,6 +185,21 @@ struct GroupedConvFwdKernelArgs
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
group_stride_c = args.K_;
|
||||
|
||||
// Initialize Split-N support fields for 2D convolution (NHWGC layout)
|
||||
// Get the actual split N from transformer
|
||||
n_per_split = conv_to_gemm_transformer.GetN();
|
||||
original_n = conv_to_gemm_transformer.GetOriginalN();
|
||||
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
|
||||
|
||||
// Calculate batch strides for NHWGC layout
|
||||
input_batch_stride =
|
||||
args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
|
||||
output_batch_stride =
|
||||
args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
|
||||
// Update GemmM to use split N (not original N)
|
||||
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
}
|
||||
|
||||
template <
|
||||
@@ -217,8 +246,7 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] *
|
||||
args.output_spatial_lengths_[2];
|
||||
// Note: GemmM will be set after Split-N calculation
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] *
|
||||
args.filter_spatial_lengths_[2];
|
||||
@@ -257,6 +285,22 @@ struct GroupedConvFwdKernelArgs
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
group_stride_c = args.K_;
|
||||
|
||||
// Initialize Split-N support fields for 3D convolution (NDHWGC layout)
|
||||
// Get the actual split N from transformer
|
||||
n_per_split = conv_to_gemm_transformer.GetN();
|
||||
original_n = conv_to_gemm_transformer.GetOriginalN();
|
||||
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
|
||||
|
||||
// Calculate batch strides for NDHWGC layout
|
||||
input_batch_stride = args.C_ * args.input_spatial_lengths_[0] *
|
||||
args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2];
|
||||
output_batch_stride = args.K_ * args.output_spatial_lengths_[0] *
|
||||
args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];
|
||||
|
||||
// Update GemmM to use split N (not original N)
|
||||
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] *
|
||||
args.output_spatial_lengths_[2];
|
||||
}
|
||||
|
||||
using AGridDescMK = remove_cvref_t<
|
||||
@@ -297,6 +341,13 @@ struct GroupedConvFwdKernelArgs
|
||||
long_index_t group_stride_a;
|
||||
long_index_t group_stride_b;
|
||||
long_index_t group_stride_c;
|
||||
|
||||
// Split-N support fields - initialize to safe defaults
|
||||
index_t n_splits = 1; // Number of batch splits (e.g., 2 for 128→64×2)
|
||||
index_t n_per_split = 1; // Batches per split (N_ from transformer)
|
||||
index_t original_n = 1; // Original batch size before splitting
|
||||
index_t input_batch_stride = 0; // Stride to next batch in input tensor
|
||||
index_t output_batch_stride = 0; // Stride to next batch in output tensor
|
||||
};
|
||||
|
||||
/// @brief The Grouped Convolution Forward kernel template.
|
||||
@@ -392,10 +443,10 @@ struct GroupedConvolutionForwardKernel
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs)
|
||||
CK_TILE_HOST static auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs)
|
||||
{
|
||||
return dim3(
|
||||
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
|
||||
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.n_splits);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto BlockSize()
|
||||
@@ -430,6 +481,17 @@ struct GroupedConvolutionForwardKernel
|
||||
}
|
||||
}
|
||||
|
||||
// Check Split-K and Split-N conflict (both use blockIdx.z)
|
||||
if(kargs.k_batch > 1 && kargs.n_splits > 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Cannot use both Split-K and Split-N simultaneously (both use blockIdx.z)!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
|
||||
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
|
||||
|
||||
@@ -768,10 +830,26 @@ struct GroupedConvolutionForwardKernel
|
||||
const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY);
|
||||
const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY);
|
||||
|
||||
// options
|
||||
const InDataType* a_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a;
|
||||
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
|
||||
OutDataType* c_ptr = static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c;
|
||||
// Split-N handling: Get which split this workgroup handles
|
||||
const auto blockIdZ = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
|
||||
// Calculate batch offset for this split
|
||||
const index_t batch_offset = __builtin_amdgcn_readfirstlane(blockIdZ * kargs.n_per_split);
|
||||
|
||||
// Calculate memory offsets for this split
|
||||
const long_index_t input_batch_offset = static_cast<long_index_t>(batch_offset) *
|
||||
static_cast<long_index_t>(kargs.input_batch_stride);
|
||||
const long_index_t output_batch_offset =
|
||||
static_cast<long_index_t>(batch_offset) *
|
||||
static_cast<long_index_t>(kargs.output_batch_stride);
|
||||
|
||||
// Adjust pointers: combine group offset and batch offset
|
||||
const InDataType* a_ptr =
|
||||
static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
|
||||
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
|
||||
group_offset_b; // No batch offset for weights!
|
||||
OutDataType* c_ptr =
|
||||
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
@@ -24,7 +24,7 @@ struct TransformConvFwdToGemm
|
||||
static constexpr auto I3 = number<3>{};
|
||||
static constexpr auto I4 = number<4>{};
|
||||
static constexpr auto I5 = number<5>{};
|
||||
#if 0 // TODO: Enable these functionalities
|
||||
|
||||
template <typename ConvDimsType>
|
||||
static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
|
||||
const ConvDimsType& strides,
|
||||
@@ -42,24 +42,40 @@ struct TransformConvFwdToGemm
|
||||
|
||||
template <typename ConvDimsType>
|
||||
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
|
||||
const ConvDimsType& a_g_n_c_wis_strides,
|
||||
const ConvDimsType& c_g_n_k_wos_lengths,
|
||||
const ConvDimsType& c_g_n_k_wos_strides)
|
||||
const ConvDimsType& c_g_n_k_wos_lengths)
|
||||
{
|
||||
// Calculate strides internally assuming contiguous memory layout
|
||||
ConvDimsType a_g_n_c_wis_strides, c_g_n_k_wos_strides;
|
||||
const index_t num_dims = a_g_n_c_wis_lengths.size();
|
||||
|
||||
// Calculate strides for input tensor (innermost to outermost)
|
||||
a_g_n_c_wis_strides[num_dims - 1] = 1;
|
||||
for(index_t i = num_dims - 2; i >= 0; i--)
|
||||
{
|
||||
a_g_n_c_wis_strides[i] = a_g_n_c_wis_strides[i + 1] * a_g_n_c_wis_lengths[i + 1];
|
||||
}
|
||||
|
||||
// Calculate strides for output tensor
|
||||
c_g_n_k_wos_strides[num_dims - 1] = 1;
|
||||
for(index_t i = num_dims - 2; i >= 0; i--)
|
||||
{
|
||||
c_g_n_k_wos_strides[i] = c_g_n_k_wos_strides[i + 1] * c_g_n_k_wos_lengths[i + 1];
|
||||
}
|
||||
|
||||
const long_index_t a_element_space_size =
|
||||
calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
|
||||
const long_index_t c_element_space_size =
|
||||
calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
|
||||
const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
|
||||
c_element_space_size * sizeof(CDataType));
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
const long_index_t element_space_size = ck_tile::max(
|
||||
a_element_space_size * sizeof(ADataType), c_element_space_size * sizeof(CDataType));
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB
|
||||
|
||||
const IndexType N = a_g_n_c_wis_lengths[I1];
|
||||
|
||||
if(element_space_size > TwoGB)
|
||||
{
|
||||
// Minimum divisor of N to not exceed 2GB
|
||||
const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
|
||||
const auto divisor = ck_tile::integer_divide_ceil(element_space_size, TwoGB);
|
||||
|
||||
if(divisor <= static_cast<double>(N))
|
||||
{
|
||||
@@ -70,7 +86,8 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
if(N % least_divisor == 0)
|
||||
{
|
||||
return N / least_divisor;
|
||||
IndexType result = N / least_divisor;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
// Not found, process one Convolution N per block
|
||||
@@ -90,9 +107,12 @@ struct TransformConvFwdToGemm
|
||||
return N;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
public:
|
||||
// Public getter methods for Split-N support
|
||||
CK_TILE_HOST constexpr IndexType GetN() const { return N_; }
|
||||
CK_TILE_HOST constexpr IndexType GetOriginalN() const { return original_N_; }
|
||||
|
||||
CK_TILE_HOST constexpr TransformConvFwdToGemm() {}
|
||||
|
||||
template <typename TransformConvFwdToGemmBase>
|
||||
@@ -100,6 +120,7 @@ struct TransformConvFwdToGemm
|
||||
TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base)
|
||||
: G_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.G_)},
|
||||
N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.N_)},
|
||||
original_N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.original_N_)},
|
||||
Di_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Di_)},
|
||||
Hi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Hi_)},
|
||||
Wi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wi_)},
|
||||
@@ -168,18 +189,14 @@ struct TransformConvFwdToGemm
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
|
||||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
|
||||
#if 0 // TODO: Enable these functionalities
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
N_ = GetSplitedNSize(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
|
||||
N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
|
||||
}
|
||||
else
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
#endif
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
|
||||
template <typename ConvDimsType,
|
||||
@@ -223,18 +240,19 @@ struct TransformConvFwdToGemm
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
|
||||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
|
||||
#if 0 // TODO: Enable these functionalities
|
||||
|
||||
// Store original N
|
||||
original_N_ = c_g_n_k_wos_lengths[I1];
|
||||
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
N_ = GetSplitedNSize(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
|
||||
N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
|
||||
}
|
||||
else
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
original_N_ = N_;
|
||||
}
|
||||
#endif
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
|
||||
template <typename ConvDimsType,
|
||||
@@ -278,18 +296,18 @@ struct TransformConvFwdToGemm
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
|
||||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
|
||||
#if 0 // TODO: Enable these functionalities
|
||||
|
||||
// Store original N before potential splitting
|
||||
original_N_ = c_g_n_k_wos_lengths[I1];
|
||||
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
N_ = GetSplitedNSize(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
|
||||
N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
|
||||
}
|
||||
else
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
N_ = original_N_;
|
||||
}
|
||||
#endif
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
}
|
||||
|
||||
#if 0 // TODO: Enable these functionalities
|
||||
@@ -1417,7 +1435,7 @@ struct TransformConvFwdToGemm
|
||||
}
|
||||
}
|
||||
|
||||
IndexType G_, N_;
|
||||
IndexType G_, N_, original_N_;
|
||||
IndexType Di_, Hi_, Wi_;
|
||||
IndexType Do_, Ho_, Wo_;
|
||||
IndexType Z_, Y_, X_;
|
||||
|
||||
Reference in New Issue
Block a user