Merge commit '99f38e4d9bedcf1b09d58653c354f042f8c509ae' into develop

This commit is contained in:
assistant-librarian[bot]
2025-11-04 00:35:23 +00:00
parent a0410f0a05
commit 58d420c0a4
5 changed files with 161 additions and 167 deletions

View File

@@ -434,14 +434,13 @@ struct GroupedConvFwdKernelArgs
/// multiplication implementation. It is responsible for storing
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
/// the output C tensor in global memory.
-template <bool EnableSplitImage_,
-typename GroupedConvTraitsType_,
+template <typename GroupedConvTraitsType_,
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct GroupedConvolutionForwardKernel
{
-static constexpr bool EnableSplitImage = EnableSplitImage_;
+static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
static constexpr ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;

View File

@@ -63,7 +63,8 @@ template <index_t NDimSpatial_,
index_t VectorSizeB_ = 1,
index_t VectorSizeC_ = 1,
index_t NumGroupsToMerge_ = 1,
-typename CDElementwise_ = PassThrough>
+typename CDElementwise_ = PassThrough,
+bool EnableSplitImage_ = false>
struct GroupedConvTraits
{
private:
@@ -74,6 +75,7 @@ struct GroupedConvTraits
}
public:
+static constexpr bool EnableSplitImage = EnableSplitImage_;
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
static constexpr index_t NDimSpatial = NDimSpatial_;
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;