mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 04:01:25 +00:00
[CK_TILE] Conv bwd splitN support (#3047)
* Conv bwd splitN support * Adjust splitting calculations to lengths format * Prepare indexing for future splitK support
This commit is contained in:
@@ -27,7 +27,8 @@ struct GroupedConvBwdDataKernelArgs
|
||||
GroupedConvTraitsType_::ConvSpecialization,
|
||||
GroupedConvTraitsType_::VectorSizeA,
|
||||
GroupedConvTraitsType_::VectorSizeB,
|
||||
GroupedConvTraitsType_::VectorSizeC>;
|
||||
GroupedConvTraitsType_::VectorSizeC,
|
||||
true>; // Split N enabled
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
@@ -121,6 +122,11 @@ struct GroupedConvBwdDataKernelArgs
|
||||
|
||||
grid_size_ += grid_size_grp;
|
||||
|
||||
// Get the actual split N from transformer
|
||||
n_per_split = conv_to_gemm_transformer.GetN();
|
||||
original_n = conv_to_gemm_transformer.GetOriginalN();
|
||||
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
|
||||
|
||||
++gemm_count;
|
||||
}
|
||||
group_stride_a = args.K_; // A: Out NWGK
|
||||
@@ -131,6 +137,9 @@ struct GroupedConvBwdDataKernelArgs
|
||||
std::multiplies<index_t>()); // B: Wei GKXC
|
||||
group_stride_c = args.C_; // C: In NWGC
|
||||
|
||||
input_batch_stride = args.C_ * args.G_ * args.input_spatial_lengths_[0];
|
||||
output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0];
|
||||
|
||||
GemmBatch = args.G_;
|
||||
}
|
||||
|
||||
@@ -237,6 +246,11 @@ struct GroupedConvBwdDataKernelArgs
|
||||
|
||||
grid_size_ += grid_size_grp;
|
||||
|
||||
// Get the actual split N from transformer
|
||||
n_per_split = conv_to_gemm_transformer.GetN();
|
||||
original_n = conv_to_gemm_transformer.GetOriginalN();
|
||||
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
|
||||
|
||||
++gemm_count;
|
||||
}
|
||||
}
|
||||
@@ -248,6 +262,11 @@ struct GroupedConvBwdDataKernelArgs
|
||||
std::multiplies<index_t>()); // B: Wei GKXC
|
||||
group_stride_c = args.C_; // C: In NWGC
|
||||
|
||||
input_batch_stride =
|
||||
args.C_ * args.G_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
|
||||
output_batch_stride =
|
||||
args.K_ * args.G_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
|
||||
GemmBatch = args.G_;
|
||||
}
|
||||
|
||||
@@ -369,6 +388,11 @@ struct GroupedConvBwdDataKernelArgs
|
||||
|
||||
grid_size_ += grid_size_grp;
|
||||
|
||||
// Get the actual split N from transformer
|
||||
n_per_split = conv_to_gemm_transformer.GetN();
|
||||
original_n = conv_to_gemm_transformer.GetOriginalN();
|
||||
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
|
||||
|
||||
++gemm_count;
|
||||
}
|
||||
}
|
||||
@@ -382,6 +406,11 @@ struct GroupedConvBwdDataKernelArgs
|
||||
std::multiplies<index_t>()); // B: Wei GKXC
|
||||
group_stride_c = args.C_; // C: In NWGC
|
||||
|
||||
input_batch_stride = args.C_ * args.G_ * args.input_spatial_lengths_[0] *
|
||||
args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2];
|
||||
output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0] *
|
||||
args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];
|
||||
|
||||
GemmBatch = args.G_; // C: In NWGC
|
||||
}
|
||||
|
||||
@@ -425,6 +454,13 @@ struct GroupedConvBwdDataKernelArgs
|
||||
long_index_t group_stride_a;
|
||||
long_index_t group_stride_b;
|
||||
long_index_t group_stride_c;
|
||||
|
||||
// Split-N support fields - initialize to safe defaults
|
||||
index_t n_splits = 1; // Number of batch splits (e.g., 2 for 128→64×2)
|
||||
index_t n_per_split = 1; // Batches per split (N_ from transformer)
|
||||
index_t original_n = 1; // Original batch size before splitting
|
||||
index_t input_batch_stride = 0; // Stride to next batch in input tensor
|
||||
index_t output_batch_stride = 0; // Stride to next batch in output tensor
|
||||
};
|
||||
|
||||
/// @brief The Grouped Convolution Backward Data kernel template.
|
||||
@@ -527,7 +563,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
CK_TILE_HOST static auto GridSize(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
|
||||
{
|
||||
// enable batched grouped gemm
|
||||
return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.k_batch);
|
||||
return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.n_splits * kargs.k_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
@@ -943,11 +979,31 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
|
||||
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
|
||||
|
||||
const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
|
||||
|
||||
// SplitN
|
||||
const index_t split_n_idx = __builtin_amdgcn_readfirstlane(blockIdZ / kargs.k_batch);
|
||||
const index_t split_n_offset =
|
||||
__builtin_amdgcn_readfirstlane(split_n_idx * kargs.n_per_split);
|
||||
|
||||
const long_index_t output_batch_offset =
|
||||
static_cast<long_index_t>(split_n_offset) *
|
||||
static_cast<long_index_t>(kargs.output_batch_stride);
|
||||
const long_index_t input_batch_offset = static_cast<long_index_t>(split_n_offset) *
|
||||
static_cast<long_index_t>(kargs.input_batch_stride);
|
||||
|
||||
// SplitK
|
||||
// TODO: Implement SplitK support
|
||||
// const index_t split_k_idx =
|
||||
// __builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
|
||||
|
||||
// options
|
||||
// conv_bwd_data = Out * Weight = In
|
||||
const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
|
||||
const OutDataType* a_ptr =
|
||||
static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a + output_batch_offset;
|
||||
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
|
||||
InDataType* c_ptr = static_cast<InDataType*>(kargs.in_ptr) + group_offset_c;
|
||||
InDataType* c_ptr =
|
||||
static_cast<InDataType*>(kargs.in_ptr) + group_offset_c + input_batch_offset;
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
Reference in New Issue
Block a user