[CK_BUILDER] Ck Tile Grouped convolution factory (#3352)

* [BUILDER] Ck Tile Grouped convolution factory

* Part 2

* Fixes after rebase

* Remove leftovers
This commit is contained in:
Bartłomiej Kocot
2025-12-08 10:32:56 +01:00
committed by GitHub
parent 8fec8054b2
commit 04612c30ce
55 changed files with 1431 additions and 92 deletions

View File

@@ -9,6 +9,13 @@
namespace ck_tile {
enum class GroupedConvDirection
{
FORWARD,
BACKWARD_DATA,
BACKWARD_WEIGHT
};
/// @brief The Grouped Conv kernel host arguments.
///
/// @par Overview
@@ -113,6 +120,36 @@ struct GroupedConvTraits
using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
template <GroupedConvDirection Direction>
struct GemmLayouts
{
static_assert(false, "Unsupported direction.");
};
template <>
struct GemmLayouts<GroupedConvDirection::FORWARD>
{
using AsLayout = AsLayoutFwd;
using BsLayout = BsLayoutFwd;
using CLayout = CLayoutFwd;
};
template <>
struct GemmLayouts<GroupedConvDirection::BACKWARD_DATA>
{
using AsLayout = AsLayoutBwdData;
using BsLayout = BsLayoutBwdData;
using CLayout = CLayoutBwdData;
};
template <>
struct GemmLayouts<GroupedConvDirection::BACKWARD_WEIGHT>
{
using AsLayout = AsLayoutBwdWeight;
using BsLayout = BsLayoutBwdWeight;
using CLayout = CLayoutBwdWeight;
};
template <ck_tile::index_t NumWaveGroups = 1>
using GroupedConvImplicitGemmTraitsFwd =
TileGemmTraits<true, true, true, AsLayoutFwd, BsLayoutFwd, CLayoutFwd, NumWaveGroups>;