[CK Tile] Grouped conv fwd splitn support (#2776)

## What's New
  Add Split-N support for grouped convolution forward to handle tensors >2GB by splitting the batch dimension.

  ## Bug Fix
  Fixed 32-bit integer overflow that caused crashes with 6+ splits:
  - Use `long_index_t` for batch offset calculations
  - Remove redundant GemmM initialization in constructors

  ## How It Works
  - Automatically splits batch dimension when tensor exceeds 2GB
  - Uses grid.z dimension for parallel processing of splits
  - Each split processes a subset of batches independently

  ## Testing
  Verified with tile_example_grouped_conv_fwd:
  - n=3000 (6 splits) ✓
  - n=3500 (7 splits) ✓
  - n=10480 (40 splits) ✓
This commit is contained in:
JH-Leon-KIM-AMD
2025-09-16 16:56:11 +03:00
committed by GitHub
parent 59cb906482
commit 804065a36b
2 changed files with 135 additions and 39 deletions

View File

@@ -24,7 +24,7 @@ struct TransformConvFwdToGemm
static constexpr auto I3 = number<3>{};
static constexpr auto I4 = number<4>{};
static constexpr auto I5 = number<5>{};
#if 0 // TODO: Enable these functionalities
template <typename ConvDimsType>
static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
const ConvDimsType& strides,
@@ -42,24 +42,40 @@ struct TransformConvFwdToGemm
template <typename ConvDimsType>
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& c_g_n_k_wos_strides)
const ConvDimsType& c_g_n_k_wos_lengths)
{
// Calculate strides internally assuming contiguous memory layout
ConvDimsType a_g_n_c_wis_strides, c_g_n_k_wos_strides;
const index_t num_dims = a_g_n_c_wis_lengths.size();
// Calculate strides for input tensor (innermost to outermost)
a_g_n_c_wis_strides[num_dims - 1] = 1;
for(index_t i = num_dims - 2; i >= 0; i--)
{
a_g_n_c_wis_strides[i] = a_g_n_c_wis_strides[i + 1] * a_g_n_c_wis_lengths[i + 1];
}
// Calculate strides for output tensor
c_g_n_k_wos_strides[num_dims - 1] = 1;
for(index_t i = num_dims - 2; i >= 0; i--)
{
c_g_n_k_wos_strides[i] = c_g_n_k_wos_strides[i + 1] * c_g_n_k_wos_lengths[i + 1];
}
const long_index_t a_element_space_size =
calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
const long_index_t c_element_space_size =
calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
c_element_space_size * sizeof(CDataType));
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const long_index_t element_space_size = ck_tile::max(
a_element_space_size * sizeof(ADataType), c_element_space_size * sizeof(CDataType));
constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB
const IndexType N = a_g_n_c_wis_lengths[I1];
if(element_space_size > TwoGB)
{
// Minimum divisor of N to not exceed 2GB
const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
const auto divisor = ck_tile::integer_divide_ceil(element_space_size, TwoGB);
if(divisor <= static_cast<double>(N))
{
@@ -70,7 +86,8 @@ struct TransformConvFwdToGemm
{
if(N % least_divisor == 0)
{
return N / least_divisor;
IndexType result = N / least_divisor;
return result;
}
}
// Not found, process one Convolution N per block
@@ -90,9 +107,12 @@ struct TransformConvFwdToGemm
return N;
}
}
#endif
public:
// Public getter methods for Split-N support
CK_TILE_HOST constexpr IndexType GetN() const { return N_; }
CK_TILE_HOST constexpr IndexType GetOriginalN() const { return original_N_; }
CK_TILE_HOST constexpr TransformConvFwdToGemm() {}
template <typename TransformConvFwdToGemmBase>
@@ -100,6 +120,7 @@ struct TransformConvFwdToGemm
TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base)
: G_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.G_)},
N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.N_)},
original_N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.original_N_)},
Di_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Di_)},
Hi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Hi_)},
Wi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wi_)},
@@ -168,18 +189,14 @@ struct TransformConvFwdToGemm
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
#if 0 // TODO: Enable these functionalities
if constexpr(SplitN)
{
N_ = GetSplitedNSize(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
}
else
{
N_ = c_g_n_k_wos_lengths[I1];
}
#endif
N_ = c_g_n_k_wos_lengths[I1];
}
template <typename ConvDimsType,
@@ -223,18 +240,19 @@ struct TransformConvFwdToGemm
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
#if 0 // TODO: Enable these functionalities
// Store original N
original_N_ = c_g_n_k_wos_lengths[I1];
if constexpr(SplitN)
{
N_ = GetSplitedNSize(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
}
else
{
N_ = c_g_n_k_wos_lengths[I1];
N_ = c_g_n_k_wos_lengths[I1];
original_N_ = N_;
}
#endif
N_ = c_g_n_k_wos_lengths[I1];
}
template <typename ConvDimsType,
@@ -278,18 +296,18 @@ struct TransformConvFwdToGemm
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
#if 0 // TODO: Enable these functionalities
// Store original N before potential splitting
original_N_ = c_g_n_k_wos_lengths[I1];
if constexpr(SplitN)
{
N_ = GetSplitedNSize(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
}
else
{
N_ = c_g_n_k_wos_lengths[I1];
N_ = original_N_;
}
#endif
N_ = c_g_n_k_wos_lengths[I1];
}
#if 0 // TODO: Enable these functionalities
@@ -1417,7 +1435,7 @@ struct TransformConvFwdToGemm
}
}
IndexType G_, N_;
IndexType G_, N_, original_N_;
IndexType Di_, Hi_, Wi_;
IndexType Do_, Ho_, Wo_;
IndexType Z_, Y_, X_;