Create invoker for the kernel and a factory for creating invokers.

This commit is contained in:
Ville Pietilä
2025-10-13 15:22:50 +00:00
parent a60dab521e
commit fc6a9e3931
3 changed files with 164 additions and 31 deletions

View File

@@ -294,26 +294,6 @@ struct GroupedConvBwdWeightKernelArgs
long_index_t group_stride_c;
};
/// @brief Abstract, type-parameterised handle for a grouped convolution backward-weight
///        operation.
///
/// The struct carries no state; it only declares the virtual entry points that a concrete
/// implementation has to provide. The template parameters identify the problem (spatial
/// rank, tensor layouts, data types, element-wise operations and compute types).
///
/// @tparam ComputeTypeA Defaults to InDataType.
/// @tparam ComputeTypeB Defaults to ComputeTypeA.
template <ck_tile::index_t NDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation,
          typename ComputeTypeA = InDataType,
          typename ComputeTypeB = ComputeTypeA>
struct GroupedConvolutionBackwardWeightInvoker
{
    /// Virtual so that implementations can be destroyed through a pointer to this interface.
    virtual ~GroupedConvolutionBackwardWeightInvoker() = default;

    /// @param args Host-side description of the convolution problem.
    /// @return A boolean verdict for @p args (presumably "can this implementation handle
    ///         the problem" -- confirm against the implementations).
    virtual bool IsSupportedArgument(const ck_tile::GroupedConvBwdWeightHostArgs& args) const = 0;

    /// @param args        Host-side description of the convolution problem.
    /// @param time_kernel Timing switch handed to the implementation.
    /// @return A float produced by the implementation (presumably the measured kernel
    ///         time -- confirm unit with the implementations).
    virtual float Run(const ck_tile::GroupedConvBwdWeightHostArgs& args, bool time_kernel) = 0;

    /// @return Name string identifying the implementation.
    virtual std::string GetName() const = 0;
};
/// @brief The Grouped Convolution Backward Weight kernel template.
///
/// @paragraph Overview Overview

View File

@@ -13,6 +13,11 @@
// #include "ck_tile/ops/common/element_wise_operation.hpp"
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
namespace ck_tile {
namespace ops {
@@ -20,6 +25,162 @@ namespace ops {
template <typename DeviceOp>
struct DeviceOperationInstanceFactory;
/// @brief Type-erased interface shared by all grouped convolution backward-weight invokers.
///
/// The template parameters describe the problem (spatial rank, tensor layouts, data types,
/// element-wise operations and compute types). Tile-level tuning parameters are not part of
/// this interface, so differently tuned implementations of the same problem can be handled
/// through one base type.
///
/// @tparam ComputeTypeA Defaults to InDataType.
/// @tparam ComputeTypeB Defaults to ComputeTypeA.
template <ck_tile::index_t NDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation,
          typename ComputeTypeA = InDataType,
          typename ComputeTypeB = ComputeTypeA>
struct GroupedConvolutionBackwardWeightBaseInvoker
{
    /// Virtual so that derived invokers can be destroyed through a base pointer.
    virtual ~GroupedConvolutionBackwardWeightBaseInvoker() = default;

    /// @param args Host-side description of the convolution problem.
    /// @return Boolean verdict of the implementation's argument check for @p args.
    virtual bool IsSupportedArgument(const ck_tile::GroupedConvBwdWeightHostArgs& args) const = 0;

    /// @param args        Host-side description of the convolution problem.
    /// @param time_kernel Timing switch handed to the implementation.
    /// @return A float produced by the implementation (presumably the average kernel
    ///         time -- confirm unit against the launch helper used by the implementation).
    virtual float Run(const ck_tile::GroupedConvBwdWeightHostArgs& args, bool time_kernel) = 0;

    /// @return Name string identifying the implementation.
    virtual std::string GetName() const = 0;
};
/// @brief Concrete grouped convolution backward-weight invoker for one tile configuration.
///
/// Builds the tile shape, tile partitioner, GEMM pipeline, C-shuffle epilogue and kernel
/// types from the template parameters and exposes the resulting kernel through the
/// GroupedConvolutionBackwardWeightBaseInvoker interface.
///
/// @note The base class is instantiated with only the first ten template arguments, so its
///       ComputeTypeA / ComputeTypeB take their defaults (InDataType).
///
/// @tparam kBlockPerCu   Template argument forwarded to ck_tile::make_kernel.
/// @tparam M_Tile,N_Tile,K_Tile                   First sequence of the TileGemmShape.
/// @tparam M_Warp,N_Warp,K_Warp                   Second sequence of the TileGemmShape.
/// @tparam M_Warp_Tile,N_Warp_Tile,K_Warp_Tile    Third sequence of the TileGemmShape.
/// @tparam VectorSizeA,VectorSizeB,VectorSizeC    Forwarded to GroupedConvTraits.
/// @tparam UseSplitK     Selects the epilogue memory operation: atomic_add when true,
///                       set when false.
template <ck_tile::index_t NDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation,
          int kBlockPerCu,
          ck_tile::index_t M_Tile,
          ck_tile::index_t N_Tile,
          ck_tile::index_t K_Tile,
          ck_tile::index_t M_Warp,
          ck_tile::index_t N_Warp,
          ck_tile::index_t K_Warp,
          ck_tile::index_t M_Warp_Tile,
          ck_tile::index_t N_Warp_Tile,
          ck_tile::index_t K_Warp_Tile,
          ck_tile::index_t VectorSizeA,
          ck_tile::index_t VectorSizeB,
          ck_tile::index_t VectorSizeC,
          bool UseSplitK>
struct GroupedConvolutionBackwardWeightInvoker
    : public GroupedConvolutionBackwardWeightBaseInvoker<NDimSpatial,
                                                         InLayout,
                                                         WeiLayout,
                                                         OutLayout,
                                                         InDataType,
                                                         WeiDataType,
                                                         OutDataType,
                                                         InElementwiseOperation,
                                                         WeiElementwiseOperation,
                                                         OutElementwiseOperation>
{
    // Block tile / warp layout / warp tile of the implicit GEMM.
    using CodegenShape_ =
        ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                               ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                               ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

    // Only the default convolution specialization is instantiated here.
    static constexpr auto ConvSpec_ = ck_tile::ConvolutionSpecialization::Default;

    using TilePartitioner_ = ck_tile::GemmTile1DPartitioner<CodegenShape_>;

    using GroupedConvTraitsType_ = ck_tile::GroupedConvTraits<NDimSpatial,
                                                              ConvSpec_,
                                                              InLayout,
                                                              WeiLayout,
                                                              OutLayout, // = DsLayout
                                                              OutLayout,
                                                              VectorSizeA,
                                                              VectorSizeB,
                                                              VectorSizeC>;

    using AccDataType    = float;
    using DsDataType     = OutDataType;
    using CDEElementWise = ck_tile::element_wise::PassThrough;

    // NOTE(review): the pipeline and the epilogue below are given (InDataType, WeiDataType)
    // as their leading data types and OutDataType as the epilogue's result type. Confirm
    // this matches the A/B/C operand roles of the backward-weight implicit GEMM; it makes
    // no difference while all three data types are identical.
    using CodegenPipelineProblem_ = ck_tile::GemmPipelineProblem<
        InDataType,
        WeiDataType,
        AccDataType,
        CodegenShape_,
        typename GroupedConvTraitsType_::GroupedConvImplicitGemmTraitsBwdWeight,
        ck_tile::element_wise::PassThrough,
        ck_tile::element_wise::PassThrough,
        InDataType,
        true,
        GroupedConvTraitsType_::VectorSizeA,
        GroupedConvTraitsType_::VectorSizeB>;

    using CodegenPipeline_ = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem_>;

    // Split-K uses atomic_add for the result; otherwise a plain set is used.
    using MemOp = std::conditional_t<
        UseSplitK,
        ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                   ck_tile::memory_operation_enum::atomic_add>,
        ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                   ck_tile::memory_operation_enum::set>>;

    using ConvEpilogue_ = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
        InDataType,
        WeiDataType,
        DsDataType,
        AccDataType,
        OutDataType,
        typename GroupedConvTraitsType_::ImplicitGemmDsLayout,
        ck_tile::tensor_layout::gemm::RowMajor,
        CDEElementWise,
        TilePartitioner_::MPerBlock,
        TilePartitioner_::NPerBlock,
        M_Warp,
        N_Warp,
        M_Warp_Tile,
        N_Warp_Tile,
        K_Warp_Tile,
        CodegenPipelineProblem_::TransposeC,
        MemOp{}.value,
        1,
        true,
        GroupedConvTraitsType_::VectorSizeC>>;

    using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
                                                                   TilePartitioner_,
                                                                   CodegenPipeline_,
                                                                   ConvEpilogue_>;

    /// @brief Builds kernel arguments from @p args and asks the kernel whether it accepts them.
    /// @param args Host-side description of the convolution problem.
    /// @return Result of Kernel::IsSupportedArgument for the derived kernel arguments.
    bool IsSupportedArgument(const ck_tile::GroupedConvBwdWeightHostArgs& args) const override
    {
        return Kernel::IsSupportedArgument(Kernel::MakeKernelArgs(args));
    }

    /// @brief Launches the kernel for @p args.
    ///
    /// The stream configuration uses the null stream with fixed warm-up (5) and repeat (50)
    /// counts. This function does not call IsSupportedArgument itself.
    ///
    /// @param args        Host-side description of the convolution problem.
    /// @param time_kernel Stored in the stream configuration as its timing switch.
    /// @return The value returned by ck_tile::launch_kernel_time_mask (presumably the
    ///         average kernel time -- confirm unit against ck_tile's launch helpers).
    float Run(const ck_tile::GroupedConvBwdWeightHostArgs& args, bool time_kernel) override
    {
        auto kargs        = Kernel::MakeKernelArgs(args);
        const dim3 grids  = Kernel::GridSize(kargs);
        const dim3 blocks = Kernel::BlockSize();

        constexpr int n_warmup = 5;
        constexpr int n_repeat = 50;

        // NOTE(review): the third initializer (1) is presumably the log level, and the 0
        // passed to make_kernel presumably the dynamic LDS size -- confirm against ck_tile.
        ck_tile::stream_config s{nullptr, time_kernel, 1, n_warmup, n_repeat};

        return ck_tile::launch_kernel_time_mask(
            s,
            Kernel::Preprocess(kargs, s),
            ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
    }

    /// @return The name reported by the underlying kernel type.
    std::string GetName() const override { return Kernel::GetName(); }

    ~GroupedConvolutionBackwardWeightInvoker() override = default;
};
template <ck_tile::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
@@ -29,7 +190,7 @@ template <ck_tile::index_t NumDimSpatial,
typename OutDataType,
typename ComputeTypeA,
typename ComputeTypeB>
struct DeviceOperationInstanceFactory<ck_tile::GroupedConvolutionBackwardWeightInvoker<
struct DeviceOperationInstanceFactory<GroupedConvolutionBackwardWeightBaseInvoker<
NumDimSpatial,
InLayout,
WeiLayout,
@@ -43,7 +204,7 @@ struct DeviceOperationInstanceFactory<ck_tile::GroupedConvolutionBackwardWeightI
ComputeTypeA,
ComputeTypeB>>
{
using DeviceOp = GroupedConvolutionBackwardWeightInvoker<NumDimSpatial,
using DeviceOp = GroupedConvolutionBackwardWeightBaseInvoker<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,

View File

@@ -105,7 +105,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
weight_dev_buf.SetZero();
output_dev_buf.ToDevice(output.data());
using DeviceOp = ck_tile::GroupedConvolutionBackwardWeightInvoker<
using DeviceOp = ops::GroupedConvolutionBackwardWeightBaseInvoker<
NDimSpatial,
InLayout,
WeiLayout,
@@ -178,14 +178,6 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
std::string op_name = op->GetName();
// constexpr int kBlockPerCu = 1;
// constexpr int n_warmup = 5;
// constexpr int n_repeat = 50;
// ck_tile::stream_config s {nullptr, time_kernel, 1, n_warmup, n_repeat};
// float avg_time = ck_tile::launch_kernel_time_mask(
// s,
// Kernel::Preprocess(kargs, s),
// ck_tile::make_kernel<kBlockPerCu>(*op, grids, blocks, 0, kargs));
float avg_time = op->Run(args, time_kernel);
std::size_t flop = conv_param.GetFlops();