[CK_BUILDER] Ck Tile Grouped convolution factory (#3352)

* [BUILDER] Ck Tile Grouped convolution factory

* Part 2

* Fixes after rebase

* Remove leftovers
This commit is contained in:
Bartłomiej Kocot
2025-12-08 10:32:56 +01:00
committed by GitHub
parent 8fec8054b2
commit 04612c30ce
55 changed files with 1431 additions and 92 deletions

View File

@@ -176,8 +176,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
return concat('_', "pipeline_AgBgCrCompV3",
concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize,
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', WaveNumM, WaveNumN),
concat('x', kPadM, kPadN, kPadK));
concat('x', kPadM, kPadN, kPadK),
Problem::GetName());
// clang-format on
}

View File

@@ -301,7 +301,12 @@ struct UniversalGemmPipelineProblem
return concat('_', "gemm_problem",
concat('x', kBlockSize),
concat('x', kPadM, kPadN, kPadK),
Scheduler);
Scheduler,
"NumWaveGroups",
NumWaveGroups,
"DoubleSmemBuffer",
DoubleSmemBuffer
);
// clang-format on
}
};

View File

@@ -560,16 +560,31 @@ struct GroupedConvolutionBackwardDataKernel
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// clang-format off
return concat('_', "grouped_convolution_backward_data",
gemm_prec_str<InDataType, WeiDataType>(),
InLayout::name,
WeiLayout::name,
OutLayout::name,
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName());
EpiloguePipeline::GetName(),
getConvSpecializationString(ConvSpecialization),
"MergedGroups",
NumGroupsToMerge,
"SplitImage",
EnableSplitImage,
"ExplicitGemm",
GroupedConvTraitsType_::ExplicitGemm
);
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
#ifdef CK_EXPERIMENTAL_BUILDER
CK_TILE_HOST std::string GetInstanceString() const
{

View File

@@ -417,26 +417,31 @@ struct GroupedConvolutionBackwardWeightKernel
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// clang-format off
if (NumGroupsToMerge > 1) {
return concat('_', "grouped_convolution_backward_weight",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName());
} else {
return concat('_', "grouped_convolution_backward_weight",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName(), "merge", NumGroupsToMerge);
}
return concat('_', "grouped_convolution_backward_weight",
gemm_prec_str<InDataType, WeiDataType>(),
InLayout::name,
WeiLayout::name,
OutLayout::name,
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName(),
getConvSpecializationString(ConvSpecialization),
"MergedGroups",
NumGroupsToMerge,
"SplitImage",
EnableSplitImage,
"ExplicitGemm",
GroupedConvTraitsType_::ExplicitGemm
);
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
#ifdef CK_EXPERIMENTAL_BUILDER
CK_TILE_HOST std::string GetInstanceString() const
{

View File

@@ -594,26 +594,28 @@ struct GroupedConvolutionForwardKernel
{
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// clang-format off
if (NumGroupsToMerge > 1) {
return concat('_', "grouped_convolution_forward",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName(),
"merge",
NumGroupsToMerge);
} else {
return concat('_', "grouped_convolution_forward",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName());
}
return concat('_', "grouped_convolution_forward",
gemm_prec_str<InDataType, WeiDataType>(),
InLayout::name,
WeiLayout::name,
OutLayout::name,
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName(),
getConvSpecializationString(ConvSpecialization),
"MergedGroups",
NumGroupsToMerge,
"SplitImage",
EnableSplitImage,
"ExplicitGemm",
GroupedConvTraitsType_::ExplicitGemm
);
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
#ifdef CK_EXPERIMENTAL_BUILDER
CK_TILE_HOST std::string GetInstanceString() const
{

View File

@@ -9,6 +9,13 @@
namespace ck_tile {
enum class GroupedConvDirection
{
FORWARD,
BACKWARD_DATA,
BACKWARD_WEIGHT
};
/// @brief The Grouped Conv kernel host arguments.
///
/// @par Overview
@@ -113,6 +120,36 @@ struct GroupedConvTraits
using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
template <GroupedConvDirection Direction>
struct GemmLayouts
{
static_assert(false, "Unsupported direction.");
};
template <>
struct GemmLayouts<GroupedConvDirection::FORWARD>
{
using AsLayout = AsLayoutFwd;
using BsLayout = BsLayoutFwd;
using CLayout = CLayoutFwd;
};
template <>
struct GemmLayouts<GroupedConvDirection::BACKWARD_DATA>
{
using AsLayout = AsLayoutBwdData;
using BsLayout = BsLayoutBwdData;
using CLayout = CLayoutBwdData;
};
template <>
struct GemmLayouts<GroupedConvDirection::BACKWARD_WEIGHT>
{
using AsLayout = AsLayoutBwdWeight;
using BsLayout = BsLayoutBwdWeight;
using CLayout = CLayoutBwdWeight;
};
template <ck_tile::index_t NumWaveGroups = 1>
using GroupedConvImplicitGemmTraitsFwd =
TileGemmTraits<true, true, true, AsLayoutFwd, BsLayoutFwd, CLayoutFwd, NumWaveGroups>;