[CK_TILE] Merge multiple fwd convolution groups into a single GEMM batch. (#3136)

* Merge fwd conv groups in CK Tile.

* Fix building CK fwd convs.

* Add number of merged groups to conv fwd kernel name.

* Get number of merged groups from conv config.

* Rename GemmConfig to ConvConfig.

* Clean-up TODOs.

* Check that number of conv groups must be divisible by the number of merged groups.

* Improve error handling in the conv fwd example.

* Fix clang-format.

* Fix group offsets.

* Fix merge problem.

* Address feedback from code review.

* Fix clang-formatting.
This commit is contained in:
Ville Pietilä
2025-12-02 15:23:32 +02:00
committed by GitHub
parent 2d3020e5b0
commit 66832861ad
4 changed files with 111 additions and 58 deletions

View File

@@ -643,8 +643,6 @@ struct GroupedConvolutionBackwardWeightKernel
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
return false;
}
// TODO: Should we also check that GemmM <= MPerBlock and GemmN <= NPerBlock?
}
return true;

View File

@@ -28,7 +28,6 @@ namespace ck_tile {
template <typename GroupedConvTraitsType_, typename CDElementwise_>
struct GroupedConvFwdKernelArgs
{
using ConvToGemmFwdTransformer =
TransformConvFwdToGemm<GroupedConvTraitsType_::NDimSpatial,
GroupedConvTraitsType_::ConvSpecialization,
@@ -40,6 +39,10 @@ struct GroupedConvFwdKernelArgs
using CDElementwise = CDElementwise_;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
static_assert(!GroupedConvTraitsType_::ExplicitGemm ||
GroupedConvTraitsType_::NumGroupsToMerge == 1,
"Explicit GEMM does not support merging convolution groups!");
template <
typename InLay = typename GroupedConvTraitsType_::InLayout,
typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
@@ -71,11 +74,6 @@ struct GroupedConvFwdKernelArgs
k_batch = args.k_batch;
// GemmM will be set after Split-N calculation
GemmN = args.K_;
GemmK = args.C_ * args.filter_spatial_lengths_[0];
GemmBatch = args.G_;
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
@@ -100,13 +98,14 @@ struct GroupedConvFwdKernelArgs
c_grid_desc_m_n =
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
group_stride_a = args.C_;
group_stride_b = args.K_ * args.C_ *
NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.C_ * NumGroupsToMerge;
group_stride_b = args.K_ * args.C_ * NumGroupsToMerge *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<index_t>());
group_stride_c = args.K_;
group_stride_c = args.K_ * NumGroupsToMerge;
// Initialize Split-N support fields for 1D convolution (NWGC layout)
// Get the actual split N from transformer
@@ -121,8 +120,20 @@ struct GroupedConvFwdKernelArgs
input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0];
output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0];
// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0];
GemmM = a_grid_desc_m_k.get_length(number<0>{});
GemmN = b_grid_desc_n_k.get_length(number<0>{});
GemmK = a_grid_desc_m_k.get_length(number<1>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsToMerge);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split
<< ", number of N splits: " << n_splits
<< ", input_batch_stride: " << input_batch_stride
<< ", output_batch_stride: " << output_batch_stride
<< ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl;
}
}
template <
@@ -163,11 +174,6 @@ struct GroupedConvFwdKernelArgs
k_batch = args.k_batch;
// Note: GemmM will be set after Split-N calculation
GemmN = args.K_;
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1];
GemmBatch = args.G_;
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
@@ -192,13 +198,14 @@ struct GroupedConvFwdKernelArgs
c_grid_desc_m_n =
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
group_stride_a = args.C_;
group_stride_b = args.K_ * args.C_ *
NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.C_ * NumGroupsToMerge;
group_stride_b = args.K_ * args.C_ * NumGroupsToMerge *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<index_t>());
group_stride_c = args.K_;
group_stride_c = args.K_ * NumGroupsToMerge;
// Initialize Split-N support fields for 2D convolution (NHWGC layout)
// Get the actual split N from transformer
@@ -213,8 +220,20 @@ struct GroupedConvFwdKernelArgs
output_batch_stride =
args.G_ * args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
GemmM = a_grid_desc_m_k.get_length(number<0>{});
GemmN = b_grid_desc_n_k.get_length(number<0>{});
GemmK = a_grid_desc_m_k.get_length(number<1>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsToMerge);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split
<< ", number of N splits: " << n_splits
<< ", input_batch_stride: " << input_batch_stride
<< ", output_batch_stride: " << output_batch_stride
<< ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl;
}
}
template <
@@ -262,12 +281,6 @@ struct GroupedConvFwdKernelArgs
k_batch = args.k_batch;
// Note: GemmM will be set after Split-N calculation
GemmN = args.K_;
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] *
args.filter_spatial_lengths_[2];
GemmBatch = args.G_;
in_ptr = args.in_ptr;
wei_ptr = args.wei_ptr;
for(index_t d = 0; d < NumDTensor; d++)
@@ -292,13 +305,14 @@ struct GroupedConvFwdKernelArgs
c_grid_desc_m_n =
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
group_stride_a = args.C_;
group_stride_b = args.K_ * args.C_ *
NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.C_ * NumGroupsToMerge;
group_stride_b = args.K_ * args.C_ * NumGroupsToMerge *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<index_t>());
group_stride_c = args.K_;
group_stride_c = args.K_ * NumGroupsToMerge;
// Initialize Split-N support fields for 3D convolution (NDHWGC layout)
// Get the actual split N from transformer
@@ -313,11 +327,21 @@ struct GroupedConvFwdKernelArgs
output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0] *
args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];
// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] *
args.output_spatial_lengths_[2];
}
GemmM = a_grid_desc_m_k.get_length(number<0>{});
GemmN = b_grid_desc_n_k.get_length(number<0>{});
GemmK = a_grid_desc_m_k.get_length(number<1>{});
GemmBatch = integer_divide_ceil(args.G_, NumGroupsToMerge);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split
<< ", number of N splits: " << n_splits
<< ", input_batch_stride: " << input_batch_stride
<< ", output_batch_stride: " << output_batch_stride
<< ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl;
}
}
using AGridDescMK = remove_cvref_t<
decltype(ConvToGemmFwdTransformer{}
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>())>;
@@ -343,6 +367,7 @@ struct GroupedConvFwdKernelArgs
index_t GemmN;
index_t GemmK;
index_t GemmBatch;
index_t NumGroupsToMerge;
const void* in_ptr;
const void* wei_ptr;
@@ -567,13 +592,25 @@ struct GroupedConvolutionForwardKernel
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// clang-format off
return concat('_', "grouped_convolution_forward",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName());
if (NumGroupsToMerge > 1) {
return concat('_', "grouped_convolution_forward",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName(),
"merge",
NumGroupsToMerge);
} else {
return concat('_', "grouped_convolution_forward",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName());
}
// clang-format on
}
@@ -742,6 +779,16 @@ struct GroupedConvolutionForwardKernel
return false;
}
if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
{
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
{
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
return false;
}
}
return true;
}