[CK TILE] Convolution remove magic values (#3160)

* [CK TILE] Refactor Conv configs and Conv Elementwise

* fix

* [CK TILE] Convolution remove magix values

* fix partitioner
This commit is contained in:
Bartłomiej Kocot
2025-11-06 11:26:30 +01:00
committed by GitHub
parent 12922120d2
commit 2234ff830b
7 changed files with 355 additions and 352 deletions

View File

@@ -74,6 +74,21 @@ struct GroupedConvTraits
}
public:
// Fixed values for Implicit GEMM
struct FixedGemmParams
{
static constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
static constexpr ck_tile::index_t TilePartitionerM01 = 4;
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr bool TransposeC = false;
static constexpr bool FixedVectorSize = true;
static constexpr bool UseStructuredSparsity = false;
static constexpr bool Persistent = false;
using ELayout = ck_tile::tensor_layout::gemm::RowMajor;
};
// Compile time parameters
static constexpr bool EnableSplitImage = EnableSplitImage_;
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
static constexpr index_t NDimSpatial = NDimSpatial_;
@@ -82,31 +97,43 @@ struct GroupedConvTraits
using WeiLayout = WeiLayout_;
using DsLayout = DsLayout_;
using OutLayout = OutLayout_;
// Forward Gemm Layouts
using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor;
using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
// Backward Data Gemm Layouts
using AsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
using BsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
using CLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
// Backward Weight Gemm Layouts
using AsLayoutBwdWeight = ck_tile::tensor_layout::gemm::ColumnMajor;
using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
template <ck_tile::index_t NumWaveGroups = 1>
using GroupedConvImplicitGemmTraitsFwd =
TileGemmTraits<true,
true,
true,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
using GroupedConvImplicitGemmTraitsBwdData =
TileGemmTraits<true,
true,
true,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
using GroupedConvImplicitGemmTraitsBwdWeight =
TileGemmTraits<true,
true,
true,
ck_tile::tensor_layout::gemm::ColumnMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
TileGemmTraits<true, true, true, AsLayoutFwd, BsLayoutFwd, CLayoutFwd, NumWaveGroups>;
template <ck_tile::index_t NumWaveGroups = 1>
using GroupedConvImplicitGemmTraitsBwdData = TileGemmTraits<true,
true,
true,
AsLayoutBwdData,
BsLayoutBwdData,
CLayoutBwdData,
NumWaveGroups>;
template <ck_tile::index_t NumWaveGroups = 1>
using GroupedConvImplicitGemmTraitsBwdWeight = TileGemmTraits<true,
true,
true,
AsLayoutBwdWeight,
BsLayoutBwdWeight,
CLayoutBwdWeight,
NumWaveGroups>;
static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_;
static constexpr index_t NumDTensor = DsLayout::size();
static constexpr ck_tile::index_t NumDTensor = DsLayout::size();
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
};