[CK_TILE] Merge multiple fwd convolution groups into a single GEMM batch. (#3136)

* Merge fwd conv groups in CK Tile.

* Fix building CK fwd convs.

* Add number of merged groups to conv fwd kernel name.

* Get number of merged groups from conv config.

* Rename GemmConfig to ConvConfig.

* Clean-up TODOs.

* Check that number of conv groups must be divisible by the number of merged groups.

* Improve error handling in the conv fwd example.

* Fix clang-format.

* Fix group offsets.

* Fix merge problem.

* Address feedback from code review.

* Fix clang-formatting.
This commit is contained in:
Ville Pietilä
2025-12-02 15:23:32 +02:00
committed by GitHub
parent 2d3020e5b0
commit 66832861ad
4 changed files with 111 additions and 58 deletions

View File

@@ -643,8 +643,6 @@ struct GroupedConvolutionBackwardWeightKernel
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
return false;
}
// TODO: Should we also check that GemmM <= MPerBlock and GemmN <= NPerBlock?
}
return true;

View File

@@ -28,7 +28,6 @@ namespace ck_tile {
template <typename GroupedConvTraitsType_, typename CDElementwise_>
struct GroupedConvFwdKernelArgs
{
using ConvToGemmFwdTransformer =
TransformConvFwdToGemm<GroupedConvTraitsType_::NDimSpatial,
GroupedConvTraitsType_::ConvSpecialization,
@@ -40,6 +39,10 @@ struct GroupedConvFwdKernelArgs
using CDElementwise = CDElementwise_;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
static_assert(!GroupedConvTraitsType_::ExplicitGemm ||
GroupedConvTraitsType_::NumGroupsToMerge == 1,
"Explicit GEMM does not support merging convolution groups!");
template <
typename InLay = typename GroupedConvTraitsType_::InLayout,
typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
@@ -71,11 +74,6 @@ struct GroupedConvFwdKernelArgs
k_batch = args.k_batch;
// GemmM will be set after Split-N calculation
GemmN = args.K_;
GemmK = args.C_ * args.filter_spatial_lengths_[0];
GemmBatch = args.G_;
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
@@ -100,13 +98,14 @@ struct GroupedConvFwdKernelArgs
c_grid_desc_m_n =
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
group_stride_a = args.C_;
group_stride_b = args.K_ * args.C_ *
NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.C_ * NumGroupsToMerge;
group_stride_b = args.K_ * args.C_ * NumGroupsToMerge *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<index_t>());
group_stride_c = args.K_;
group_stride_c = args.K_ * NumGroupsToMerge;
// Initialize Split-N support fields for 1D convolution (NWGC layout)
// Get the actual split N from transformer
@@ -121,8 +120,20 @@ struct GroupedConvFwdKernelArgs
input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0];
output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0];
// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0];
GemmM = a_grid_desc_m_k.get_length(number<0>{});
GemmN = b_grid_desc_n_k.get_length(number<0>{});
GemmK = a_grid_desc_m_k.get_length(number<1>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsToMerge);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split
<< ", number of N splits: " << n_splits
<< ", input_batch_stride: " << input_batch_stride
<< ", output_batch_stride: " << output_batch_stride
<< ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl;
}
}
template <
@@ -163,11 +174,6 @@ struct GroupedConvFwdKernelArgs
k_batch = args.k_batch;
// Note: GemmM will be set after Split-N calculation
GemmN = args.K_;
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1];
GemmBatch = args.G_;
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
@@ -192,13 +198,14 @@ struct GroupedConvFwdKernelArgs
c_grid_desc_m_n =
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
group_stride_a = args.C_;
group_stride_b = args.K_ * args.C_ *
NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.C_ * NumGroupsToMerge;
group_stride_b = args.K_ * args.C_ * NumGroupsToMerge *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<index_t>());
group_stride_c = args.K_;
group_stride_c = args.K_ * NumGroupsToMerge;
// Initialize Split-N support fields for 2D convolution (NHWGC layout)
// Get the actual split N from transformer
@@ -213,8 +220,20 @@ struct GroupedConvFwdKernelArgs
output_batch_stride =
args.G_ * args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
GemmM = a_grid_desc_m_k.get_length(number<0>{});
GemmN = b_grid_desc_n_k.get_length(number<0>{});
GemmK = a_grid_desc_m_k.get_length(number<1>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsToMerge);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split
<< ", number of N splits: " << n_splits
<< ", input_batch_stride: " << input_batch_stride
<< ", output_batch_stride: " << output_batch_stride
<< ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl;
}
}
template <
@@ -262,12 +281,6 @@ struct GroupedConvFwdKernelArgs
k_batch = args.k_batch;
// Note: GemmM will be set after Split-N calculation
GemmN = args.K_;
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] *
args.filter_spatial_lengths_[2];
GemmBatch = args.G_;
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
@@ -292,13 +305,14 @@ struct GroupedConvFwdKernelArgs
c_grid_desc_m_n =
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
group_stride_a = args.C_;
group_stride_b = args.K_ * args.C_ *
NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.C_ * NumGroupsToMerge;
group_stride_b = args.K_ * args.C_ * NumGroupsToMerge *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<index_t>());
group_stride_c = args.K_;
group_stride_c = args.K_ * NumGroupsToMerge;
// Initialize Split-N support fields for 3D convolution (NDHWGC layout)
// Get the actual split N from transformer
@@ -313,11 +327,21 @@ struct GroupedConvFwdKernelArgs
output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0] *
args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];
// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] *
args.output_spatial_lengths_[2];
}
GemmM = a_grid_desc_m_k.get_length(number<0>{});
GemmN = b_grid_desc_n_k.get_length(number<0>{});
GemmK = a_grid_desc_m_k.get_length(number<1>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsToMerge);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split
<< ", number of N splits: " << n_splits
<< ", input_batch_stride: " << input_batch_stride
<< ", output_batch_stride: " << output_batch_stride
<< ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl;
}
}
using AGridDescMK = remove_cvref_t<
decltype(ConvToGemmFwdTransformer{}
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>())>;
@@ -343,6 +367,7 @@ struct GroupedConvFwdKernelArgs
index_t GemmN;
index_t GemmK;
index_t GemmBatch;
index_t NumGroupsToMerge;
const void* in_ptr;
const void* wei_ptr;
@@ -567,13 +592,25 @@ struct GroupedConvolutionForwardKernel
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// clang-format off
return concat('_', "grouped_convolution_forward",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName());
if (NumGroupsToMerge > 1) {
return concat('_', "grouped_convolution_forward",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName(),
"merge",
NumGroupsToMerge);
} else {
return concat('_', "grouped_convolution_forward",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName());
}
// clang-format on
}
@@ -742,6 +779,16 @@ struct GroupedConvolutionForwardKernel
return false;
}
if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
{
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
{
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
return false;
}
}
return true;
}

View File

@@ -470,10 +470,10 @@ struct TransformConvFwdToGemm
bool>::type = false>
CK_TILE_HOST auto MakeADescriptor_M_K() const
{
IndexType NStrideTensorA_ = Wi_ * G_ * C_;
IndexType WiStride_ = G_ * C_;
IndexType CStrideTensorA_ = 1;
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
IndexType GStrideTensorA_ = C_;
IndexType CStrideTensorA_ = 1;
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
@@ -701,11 +701,11 @@ struct TransformConvFwdToGemm
CK_TILE_HOST auto MakeADescriptor_M_K() const
{
IndexType NStrideTensorA_ = Hi_ * Wi_ * G_ * C_;
IndexType HiStride_ = Wi_ * G_ * C_;
IndexType WiStride_ = G_ * C_;
IndexType CStrideTensorA_ = 1;
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
IndexType GStrideTensorA_ = C_;
IndexType CStrideTensorA_ = 1;
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
@@ -960,12 +960,12 @@ struct TransformConvFwdToGemm
CK_TILE_HOST auto MakeADescriptor_M_K() const
{
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
IndexType DiStride_ = Hi_ * Wi_ * G_ * C_;
IndexType HiStride_ = Wi_ * G_ * C_;
IndexType WiStride_ = G_ * C_;
IndexType CStrideTensorA_ = 1;
IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
IndexType GStrideTensorA_ = C_;
IndexType CStrideTensorA_ = 1;
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
{
@@ -1289,9 +1289,9 @@ struct TransformConvFwdToGemm
bool>::type = false>
CK_TILE_HOST auto MakeBDescriptor_N_K() const
{
IndexType CStrideTensorB_ = 1;
IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_;
IndexType GStrideTensorB_ = K_ * Z_ * Y_ * X_ * C_;
IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_;
IndexType CStrideTensorB_ = 1;
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
{
@@ -1356,10 +1356,10 @@ struct TransformConvFwdToGemm
bool>::type = false>
CK_TILE_HOST auto MakeCDescriptor_M_N() const
{
IndexType NStrideTensorC_ = Wo_ * G_ * K_;
IndexType WoStride_ = G_ * K_;
IndexType KStrideTensorC_ = 1;
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
IndexType GStrideTensorC_ = K_;
IndexType KStrideTensorC_ = 1;
const IndexType NDoHoWo = N_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
@@ -1417,11 +1417,11 @@ struct TransformConvFwdToGemm
bool>::type = false>
CK_TILE_HOST auto MakeCDescriptor_M_N() const
{
IndexType NStrideTensorC_ = Ho_ * Wo_ * G_ * K_;
IndexType HoStride_ = Wo_ * G_ * K_;
IndexType WoStride_ = G_ * K_;
IndexType KStrideTensorC_ = 1;
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
IndexType GStrideTensorC_ = K_;
IndexType KStrideTensorC_ = 1;
const IndexType NDoHoWo = N_ * Ho_ * Wo_;
if constexpr(NumGroupsToMerge == 1)
@@ -1482,12 +1482,12 @@ struct TransformConvFwdToGemm
bool>::type = false>
CK_TILE_HOST auto MakeCDescriptor_M_N() const
{
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
IndexType DoStride_ = Ho_ * Wo_ * G_ * K_;
IndexType HoStride_ = Wo_ * G_ * K_;
IndexType WoStride_ = G_ * K_;
IndexType KStrideTensorC_ = 1;
IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
IndexType GStrideTensorC_ = K_;
IndexType KStrideTensorC_ = 1;
const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
if constexpr(NumGroupsToMerge == 1)