Merge commit '6c2ca1211ae29802281049843d284ba1bd6511f8' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-27 18:15:18 +00:00
parent 9cdbee7709
commit d3e72e87c4
32 changed files with 2051 additions and 44 deletions

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
namespace ck_tile {
@@ -14,7 +15,7 @@ namespace ck_tile {
/// This structure is passed to Grouped Convolution Kernels when creating kernel
/// arguments object. It contain all necessary information required to
/// build proper kernel argument and launch kernel on GPU.
template <typename InPtr, typename WeiPtr, typename OutPtr>
template <typename InPtr, typename WeiPtr, typename OutPtr, typename CDElementwise>
struct GroupedConvHostArgs : public conv::ConvParam
{
CK_TILE_HOST GroupedConvHostArgs() = delete;
@@ -23,13 +24,15 @@ struct GroupedConvHostArgs : public conv::ConvParam
WeiPtr wei_ptr_,
const std::vector<const void*> ds_ptr_,
OutPtr out_ptr_,
index_t k_batch_)
index_t k_batch_,
CDElementwise elfunc_ = CDElementwise{})
: conv::ConvParam(conv_param),
in_ptr(in_ptr_),
wei_ptr(wei_ptr_),
ds_ptr(ds_ptr_),
out_ptr(out_ptr_),
k_batch(k_batch_)
k_batch(k_batch_),
elfunc(elfunc_)
{
}
@@ -38,11 +41,17 @@ struct GroupedConvHostArgs : public conv::ConvParam
const std::vector<const void*> ds_ptr;
OutPtr out_ptr;
index_t k_batch;
const CDElementwise elfunc;
};
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*>;
using GroupedConvBwdWeightHostArgs = GroupedConvHostArgs<const void*, void*, const void*>;
using GroupedConvBwdDataHostArgs = GroupedConvHostArgs<void*, const void*, const void*>;
using PassThrough = ck_tile::element_wise::PassThrough;
template <typename CDElementwise = PassThrough>
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*, CDElementwise>;
using GroupedConvBwdWeightHostArgs =
GroupedConvHostArgs<const void*, void*, const void*, PassThrough>;
using GroupedConvBwdDataHostArgs =
GroupedConvHostArgs<void*, const void*, const void*, PassThrough>;
template <index_t NDimSpatial_,
ConvolutionSpecialization ConvSpecialization_,
@@ -50,9 +59,10 @@ template <index_t NDimSpatial_,
typename WeiLayout_,
typename DsLayout_,
typename OutLayout_,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
index_t VectorSizeC_ = 1>
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
index_t VectorSizeC_ = 1,
typename CDElementwise_ = PassThrough>
struct GroupedConvTraits
{
private:
@@ -70,6 +80,7 @@ struct GroupedConvTraits
using WeiLayout = WeiLayout_;
using DsLayout = DsLayout_;
using OutLayout = OutLayout_;
using CDElementwise = CDElementwise_;
using GroupedConvImplicitGemmTraitsFwd =
TileGemmTraits<true,
true,