mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK TILE] Refactor grouped conv fwd large tensor (#3144)
This commit is contained in:
@@ -434,14 +434,13 @@ struct GroupedConvFwdKernelArgs
|
||||
/// multiplication implementation. It is responsible for storing
|
||||
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
|
||||
/// the output C tensor in global memory.
|
||||
template <bool EnableSplitImage_,
|
||||
typename GroupedConvTraitsType_,
|
||||
template <typename GroupedConvTraitsType_,
|
||||
typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_>
|
||||
struct GroupedConvolutionForwardKernel
|
||||
{
|
||||
static constexpr bool EnableSplitImage = EnableSplitImage_;
|
||||
static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization =
|
||||
GroupedConvTraitsType_::ConvSpecialization;
|
||||
|
||||
@@ -63,7 +63,8 @@ template <index_t NDimSpatial_,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1,
|
||||
index_t NumGroupsToMerge_ = 1,
|
||||
typename CDElementwise_ = PassThrough>
|
||||
typename CDElementwise_ = PassThrough,
|
||||
bool EnableSplitImage_ = false>
|
||||
struct GroupedConvTraits
|
||||
{
|
||||
private:
|
||||
@@ -74,6 +75,7 @@ struct GroupedConvTraits
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr bool EnableSplitImage = EnableSplitImage_;
|
||||
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
|
||||
static constexpr index_t NDimSpatial = NDimSpatial_;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
|
||||
|
||||
Reference in New Issue
Block a user