[CK_TILE] Conv bwd splitN support (#3047)

* Conv bwd splitN support

* Adjust splitting calculations to lengths format

* Prepare indexing for future splitK support
This commit is contained in:
Johannes Graner
2025-10-22 13:34:06 +02:00
committed by GitHub
parent 5a27a97391
commit cbd1279ae6
2 changed files with 116 additions and 31 deletions

View File

@@ -27,7 +27,8 @@ struct GroupedConvBwdDataKernelArgs
GroupedConvTraitsType_::ConvSpecialization,
GroupedConvTraitsType_::VectorSizeA,
GroupedConvTraitsType_::VectorSizeB,
GroupedConvTraitsType_::VectorSizeC>;
GroupedConvTraitsType_::VectorSizeC,
true>; // Split N enabled
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
static constexpr auto I0 = number<0>();
@@ -121,6 +122,11 @@ struct GroupedConvBwdDataKernelArgs
grid_size_ += grid_size_grp;
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
++gemm_count;
}
group_stride_a = args.K_; // A: Out NWGK
@@ -131,6 +137,9 @@ struct GroupedConvBwdDataKernelArgs
std::multiplies<index_t>()); // B: Wei GKXC
group_stride_c = args.C_; // C: In NWGC
input_batch_stride = args.C_ * args.G_ * args.input_spatial_lengths_[0];
output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0];
GemmBatch = args.G_;
}
@@ -237,6 +246,11 @@ struct GroupedConvBwdDataKernelArgs
grid_size_ += grid_size_grp;
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
++gemm_count;
}
}
@@ -248,6 +262,11 @@ struct GroupedConvBwdDataKernelArgs
std::multiplies<index_t>()); // B: Wei GKXC
group_stride_c = args.C_; // C: In NWGC
input_batch_stride =
args.C_ * args.G_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
output_batch_stride =
args.K_ * args.G_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
GemmBatch = args.G_;
}
@@ -369,6 +388,11 @@ struct GroupedConvBwdDataKernelArgs
grid_size_ += grid_size_grp;
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
++gemm_count;
}
}
@@ -382,6 +406,11 @@ struct GroupedConvBwdDataKernelArgs
std::multiplies<index_t>()); // B: Wei GKXC
group_stride_c = args.C_; // C: In NWGC
input_batch_stride = args.C_ * args.G_ * args.input_spatial_lengths_[0] *
args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2];
output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0] *
args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];
GemmBatch = args.G_; // C: In NWGC
}
@@ -425,6 +454,13 @@ struct GroupedConvBwdDataKernelArgs
long_index_t group_stride_a;
long_index_t group_stride_b;
long_index_t group_stride_c;
// Split-N support fields - initialize to safe defaults
index_t n_splits = 1; // Number of batch splits (e.g., 2 for 128→64×2)
index_t n_per_split = 1; // Batches per split (N_ from transformer)
index_t original_n = 1; // Original batch size before splitting
index_t input_batch_stride = 0; // Stride to next batch in input tensor
index_t output_batch_stride = 0; // Stride to next batch in output tensor
};
/// @brief The Grouped Convolution Backward Data kernel template.
@@ -527,7 +563,7 @@ struct GroupedConvolutionBackwardDataKernel
CK_TILE_HOST static auto GridSize(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
{
// enable batched grouped gemm
return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.k_batch);
return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.n_splits * kargs.k_batch);
}
CK_TILE_HOST static constexpr auto BlockSize()
@@ -943,11 +979,31 @@ struct GroupedConvolutionBackwardDataKernel
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
// SplitN
const index_t split_n_idx = __builtin_amdgcn_readfirstlane(blockIdZ / kargs.k_batch);
const index_t split_n_offset =
__builtin_amdgcn_readfirstlane(split_n_idx * kargs.n_per_split);
const long_index_t output_batch_offset =
static_cast<long_index_t>(split_n_offset) *
static_cast<long_index_t>(kargs.output_batch_stride);
const long_index_t input_batch_offset = static_cast<long_index_t>(split_n_offset) *
static_cast<long_index_t>(kargs.input_batch_stride);
// SplitK
// TODO: Implement SplitK support
// const index_t split_k_idx =
// __builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
// options
// conv_bwd_data = Out * Weight = In
const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
const OutDataType* a_ptr =
static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a + output_batch_offset;
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
InDataType* c_ptr = static_cast<InDataType*>(kargs.in_ptr) + group_offset_c;
InDataType* c_ptr =
static_cast<InDataType*>(kargs.in_ptr) + group_offset_c + input_batch_offset;
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];

View File

@@ -27,7 +27,7 @@ struct TransformConvBwdDataToGemm
static constexpr auto I3 = number<3>{};
static constexpr auto I4 = number<4>{};
static constexpr auto I5 = number<5>{};
#if 0 // TODO: Enable these functionalities
template <typename ConvDimsType>
static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
const ConvDimsType& strides,
@@ -44,25 +44,45 @@ struct TransformConvBwdDataToGemm
}
template <typename ConvDimsType>
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& c_g_n_k_wos_strides)
static IndexType GetSplitedNSize(const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& a_g_n_c_wis_lengths)
{
// Calculate strides internally assuming contiguous memory layout
ConvDimsType c_g_n_k_wos_strides, a_g_n_c_wis_strides;
const index_t num_dims = c_g_n_k_wos_strides.size();
// Calculate strides for input tensor (innermost to outermost),
// Don't include outermost dimension G since it's gemm batch.
a_g_n_c_wis_strides[num_dims - 1] = 1;
for(index_t i = num_dims - 2; i >= 1; i--)
{
a_g_n_c_wis_strides[i] = a_g_n_c_wis_strides[i + 1] * a_g_n_c_wis_lengths[i + 1];
}
// Calculate strides for output tensor,
// Don't include outermost dimension G since it's gemm batch.
c_g_n_k_wos_strides[num_dims - 1] = 1;
for(index_t i = num_dims - 2; i >= 1; i--)
{
c_g_n_k_wos_strides[i] = c_g_n_k_wos_strides[i + 1] * c_g_n_k_wos_lengths[i + 1];
}
const long_index_t a_element_space_size =
calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
const long_index_t c_element_space_size =
calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
c_element_space_size * sizeof(CDataType));
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const long_index_t element_space_size = ck_tile::max(
a_element_space_size * sizeof(ADataType), c_element_space_size * sizeof(CDataType));
const IndexType N = a_g_n_c_wis_lengths[I1];
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const IndexType N = c_g_n_k_wos_lengths[I1];
if(element_space_size > TwoGB)
{
// Minimum divisor of N to not exceed 2GB
const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
const auto divisor = ck_tile::integer_divide_ceil(element_space_size, TwoGB);
if(divisor <= static_cast<double>(N))
{
@@ -93,9 +113,12 @@ struct TransformConvBwdDataToGemm
return N;
}
}
#endif
public:
// Public getter methods for Split-N support
CK_TILE_HOST constexpr IndexType GetN() const { return N_; }
CK_TILE_HOST constexpr IndexType GetOriginalN() const { return original_N_; }
CK_TILE_HOST constexpr TransformConvBwdDataToGemm() {}
template <typename TransformConvBwdDataToGemmBase>
@@ -103,6 +126,7 @@ struct TransformConvBwdDataToGemm
TransformConvBwdDataToGemm(const TransformConvBwdDataToGemmBase& transform_conv_to_gemm_base)
: G_{static_cast<IndexType>(transform_conv_to_gemm_base.G_)},
N_{static_cast<IndexType>(transform_conv_to_gemm_base.N_)},
original_N_{static_cast<IndexType>(transform_conv_to_gemm_base.original_N_)},
Di_{static_cast<IndexType>(transform_conv_to_gemm_base.Di_)},
Hi_{static_cast<IndexType>(transform_conv_to_gemm_base.Hi_)},
Wi_{static_cast<IndexType>(transform_conv_to_gemm_base.Wi_)},
@@ -170,17 +194,18 @@ struct TransformConvBwdDataToGemm
IdxYTilde_{I1},
IdxXTilde_{tildes[I0]}
{
#if 0 // TODO: Enable these functionalities
// Store original N
original_N_ = a_g_n_c_wis_lengths[I1];
if constexpr(SplitN)
{
N_ = GetSplitedNSize(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
N_ = GetSplitedNSize(c_g_n_k_wos_lengths, a_g_n_c_wis_lengths);
}
else
{
N_ = c_g_n_k_wos_lengths[I1];
N_ = a_g_n_c_wis_lengths[I1];
}
#endif
GcdStrideDilationW_ = gcd(ConvStrideW_, ConvDilationW_);
XTilde_ = ConvStrideW_ / GcdStrideDilationW_;
@@ -229,17 +254,19 @@ struct TransformConvBwdDataToGemm
IdxYTilde_{tildes[I0]},
IdxXTilde_{tildes[I1]}
{
#if 0 // TODO: Enable these functionalities
// Store original N
original_N_ = a_g_n_c_wis_lengths[I1];
if constexpr(SplitN)
{
N_ = GetSplitedNSize(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
N_ = GetSplitedNSize(c_g_n_k_wos_lengths, a_g_n_c_wis_lengths);
}
else
{
N_ = c_g_n_k_wos_lengths[I1];
N_ = a_g_n_c_wis_lengths[I1];
}
#endif
GcdStrideDilationW_ = gcd(ConvStrideW_, ConvDilationW_);
GcdStrideDilationH_ = gcd(ConvStrideH_, ConvDilationH_);
XTilde_ = ConvStrideW_ / GcdStrideDilationW_;
@@ -291,17 +318,19 @@ struct TransformConvBwdDataToGemm
IdxYTilde_{tildes[I1]},
IdxXTilde_{tildes[I2]}
{
#if 0 // TODO: Enable these functionalities
// Store original N
original_N_ = a_g_n_c_wis_lengths[I1];
if constexpr(SplitN)
{
N_ = GetSplitedNSize(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
N_ = GetSplitedNSize(c_g_n_k_wos_lengths, a_g_n_c_wis_lengths);
}
else
{
N_ = c_g_n_k_wos_lengths[I1];
N_ = a_g_n_c_wis_lengths[I1];
}
#endif
GcdStrideDilationW_ = gcd(ConvStrideW_, ConvDilationW_);
GcdStrideDilationH_ = gcd(ConvStrideH_, ConvDilationH_);
GcdStrideDilationD_ = gcd(ConvStrideD_, ConvDilationD_);
@@ -1068,7 +1097,7 @@ struct TransformConvBwdDataToGemm
in_gemmmraw_gemmnraw_grid_desc);
}
IndexType G_, N_;
IndexType G_, N_, original_N_;
IndexType Di_, Hi_, Wi_;
IndexType Do_, Ho_, Wo_;
IndexType Z_, Y_, X_;