[CK TILE] Refactor Conv configs and Conv Elementwise (#3151)

* [CK TILE] Refactor Conv configs and Conv Elementwise

* fix
This commit is contained in:
Bartłomiej Kocot
2025-11-04 15:04:53 +01:00
committed by GitHub
parent 99f38e4d9b
commit 8681ced962
14 changed files with 230 additions and 219 deletions

View File

@@ -19,7 +19,7 @@
namespace ck_tile {
/// @brief The Grouped Convolution kernel device arguments.
template <typename GroupedConvTraitsType_>
template <typename GroupedConvTraitsType_, typename CDElementwise_>
struct GroupedConvFwdKernelArgs
{
@@ -31,7 +31,7 @@ struct GroupedConvFwdKernelArgs
GroupedConvTraitsType_::VectorSizeC,
GroupedConvTraitsType_::NumGroupsToMerge,
true>; // Split N enabled
using CDElementwise = typename GroupedConvTraitsType_::CDElementwise;
using CDElementwise = CDElementwise_;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
template <
@@ -469,7 +469,8 @@ struct GroupedConvolutionForwardKernel
using CDElementwise = typename EpiloguePipeline::CDElementwise;
using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs<GroupedConvTraitsType_>;
using GroupedConvFwdKernelArgsSpecialized =
GroupedConvFwdKernelArgs<GroupedConvTraitsType_, CDElementwise>;
static constexpr bool IsSplitKSupported = false;

View File

@@ -63,7 +63,6 @@ template <index_t NDimSpatial_,
index_t VectorSizeB_ = 1,
index_t VectorSizeC_ = 1,
index_t NumGroupsToMerge_ = 1,
typename CDElementwise_ = PassThrough,
bool EnableSplitImage_ = false>
struct GroupedConvTraits
{
@@ -83,7 +82,6 @@ struct GroupedConvTraits
using WeiLayout = WeiLayout_;
using DsLayout = DsLayout_;
using OutLayout = OutLayout_;
using CDElementwise = CDElementwise_;
using GroupedConvImplicitGemmTraitsFwd =
TileGemmTraits<true,
true,