mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK TILE] Convolution remove magic values (#3160)
* [CK TILE] Refactor Conv configs and Conv Elementwise * fix * [CK TILE] Convolution remove magix values * fix partitioner
This commit is contained in:
@@ -74,6 +74,21 @@ struct GroupedConvTraits
|
||||
}
|
||||
|
||||
public:
|
||||
// Fixed values for Implicit GEMM
|
||||
struct FixedGemmParams
|
||||
{
|
||||
static constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TilePartitionerM01 = 4;
|
||||
static constexpr bool kPadM = true;
|
||||
static constexpr bool kPadN = true;
|
||||
static constexpr bool kPadK = true;
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool FixedVectorSize = true;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr bool Persistent = false;
|
||||
using ELayout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
};
|
||||
// Compile time parameters
|
||||
static constexpr bool EnableSplitImage = EnableSplitImage_;
|
||||
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
|
||||
static constexpr index_t NDimSpatial = NDimSpatial_;
|
||||
@@ -82,31 +97,43 @@ struct GroupedConvTraits
|
||||
using WeiLayout = WeiLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using OutLayout = OutLayout_;
|
||||
|
||||
// Forward Gemm Layouts
|
||||
using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
// Backward Data Gemm Layouts
|
||||
using AsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using BsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using CLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
// Backward Weight Gemm Layouts
|
||||
using AsLayoutBwdWeight = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
|
||||
template <ck_tile::index_t NumWaveGroups = 1>
|
||||
using GroupedConvImplicitGemmTraitsFwd =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
using GroupedConvImplicitGemmTraitsBwdData =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
using GroupedConvImplicitGemmTraitsBwdWeight =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
TileGemmTraits<true, true, true, AsLayoutFwd, BsLayoutFwd, CLayoutFwd, NumWaveGroups>;
|
||||
template <ck_tile::index_t NumWaveGroups = 1>
|
||||
using GroupedConvImplicitGemmTraitsBwdData = TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
AsLayoutBwdData,
|
||||
BsLayoutBwdData,
|
||||
CLayoutBwdData,
|
||||
NumWaveGroups>;
|
||||
template <ck_tile::index_t NumWaveGroups = 1>
|
||||
using GroupedConvImplicitGemmTraitsBwdWeight = TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
AsLayoutBwdWeight,
|
||||
BsLayoutBwdWeight,
|
||||
CLayoutBwdWeight,
|
||||
NumWaveGroups>;
|
||||
static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
|
||||
static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
|
||||
static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t NumDTensor = DsLayout::size();
|
||||
static constexpr ck_tile::index_t NumDTensor = DsLayout::size();
|
||||
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user