[CK_BUILDER] Ck Tile Grouped convolution factory (#3352)

* [BUILDER] Ck Tile Grouped convolution factory

* Part 2

* Fixes after rebase

* Remove leftovers
This commit is contained in:
Bartłomiej Kocot
2025-12-08 10:32:56 +01:00
committed by GitHub
parent 8fec8054b2
commit 04612c30ce
55 changed files with 1431 additions and 92 deletions

View File

@@ -560,16 +560,31 @@ struct GroupedConvolutionBackwardDataKernel
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// clang-format off
return concat('_', "grouped_convolution_backward_data",
gemm_prec_str<InDataType, WeiDataType>(),
InLayout::name,
WeiLayout::name,
OutLayout::name,
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName());
EpiloguePipeline::GetName(),
getConvSpecializationString(ConvSpecialization),
"MergedGroups",
NumGroupsToMerge,
"SplitImage",
EnableSplitImage,
"ExplicitGemm",
GroupedConvTraitsType_::ExplicitGemm
);
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
#ifdef CK_EXPERIMENTAL_BUILDER
CK_TILE_HOST std::string GetInstanceString() const
{

View File

@@ -417,26 +417,31 @@ struct GroupedConvolutionBackwardWeightKernel
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// clang-format off
if (NumGroupsToMerge > 1) {
return concat('_', "grouped_convolution_backward_weight",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName());
} else {
return concat('_', "grouped_convolution_backward_weight",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName(), "merge", NumGroupsToMerge);
}
return concat('_', "grouped_convolution_backward_weight",
gemm_prec_str<InDataType, WeiDataType>(),
InLayout::name,
WeiLayout::name,
OutLayout::name,
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName(),
getConvSpecializationString(ConvSpecialization),
"MergedGroups",
NumGroupsToMerge,
"SplitImage",
EnableSplitImage,
"ExplicitGemm",
GroupedConvTraitsType_::ExplicitGemm
);
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
#ifdef CK_EXPERIMENTAL_BUILDER
CK_TILE_HOST std::string GetInstanceString() const
{

View File

@@ -594,26 +594,28 @@ struct GroupedConvolutionForwardKernel
{
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// clang-format off
if (NumGroupsToMerge > 1) {
return concat('_', "grouped_convolution_forward",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName(),
"merge",
NumGroupsToMerge);
} else {
return concat('_', "grouped_convolution_forward",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName());
}
return concat('_', "grouped_convolution_forward",
gemm_prec_str<InDataType, WeiDataType>(),
InLayout::name,
WeiLayout::name,
OutLayout::name,
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName(),
getConvSpecializationString(ConvSpecialization),
"MergedGroups",
NumGroupsToMerge,
"SplitImage",
EnableSplitImage,
"ExplicitGemm",
GroupedConvTraitsType_::ExplicitGemm
);
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
#ifdef CK_EXPERIMENTAL_BUILDER
CK_TILE_HOST std::string GetInstanceString() const
{