Merge commit '9fcc1ee9fd9730efd865f530afde505f2556954d' into develop

This commit is contained in:
assistant-librarian[bot]
2025-08-18 17:12:50 +00:00
parent d436787ed0
commit 68b20e1d4f
113 changed files with 610 additions and 531 deletions

View File

@@ -354,7 +354,7 @@ struct GroupedConvolutionBackwardWeightKernel
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
using InDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
@@ -393,7 +393,7 @@ struct GroupedConvolutionBackwardWeightKernel
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)

View File

@@ -361,7 +361,7 @@ struct GroupedConvolutionForwardKernel
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
using InDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
@@ -398,7 +398,7 @@ struct GroupedConvolutionForwardKernel
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs)