From 044bcfcb1e47993c44c10d172bf75b5254273eea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Thu, 16 Oct 2025 11:03:14 +0000 Subject: [PATCH] Take universal GEMM pipeline into use for grouped convolutions. --- ...ped_convolution_backward_weight_kernel.hpp | 2 +- .../gpu/gemm_configs.hpp | 91 ++++++ ...grouped_conv_bwd_weight_bf16_instances.hpp | 162 +---------- ...grouped_conv_bwd_weight_fp16_instances.hpp | 23 +- .../tile_grouped_conv_bwd_weight_invoker.hpp | 271 ++++++++++-------- .../tile_grouped_conv_instance_factory.hpp | 5 + 6 files changed, 278 insertions(+), 276 deletions(-) create mode 100644 library/include/ck_tile/library/tensor_operation_instance/gpu/gemm_configs.hpp diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 7906aa3389..874a23d670 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -466,7 +466,7 @@ struct GroupedConvolutionBackwardWeightKernel { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { - CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + CK_TILE_ERROR("Conditions not met for Kbatch > 1!"); } return false; } diff --git a/library/include/ck_tile/library/tensor_operation_instance/gpu/gemm_configs.hpp b/library/include/ck_tile/library/tensor_operation_instance/gpu/gemm_configs.hpp new file mode 100644 index 0000000000..509486569b --- /dev/null +++ b/library/include/ck_tile/library/tensor_operation_instance/gpu/gemm_configs.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/utility/json_dump.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_COMPUTE_V5 4 + +namespace ck_tile { +namespace ops { + +using MemoryOpSet = std::integral_constant; + +using MemoryOpAtomicAdd = std::integral_constant; + +struct GemmConfigBase +{ + static constexpr bool kPadM = true; + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool Preshuffle = false; + static constexpr bool TiledMMAPermuteN = false; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; +}; + +} // namespace ops +} // namespace ck_tile diff --git a/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_bwd_weight_bf16_instances.hpp b/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_bwd_weight_bf16_instances.hpp index bd081f569a..7428393ab6 100644 --- a/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_bwd_weight_bf16_instances.hpp +++ b/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_bwd_weight_bf16_instances.hpp @@ -17,155 +17,21 @@ template using tile_grouped_conv_bwd_weight_bf16_instances = std::tuple< - // clang-format off - //#####################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| In| Wei| Out| K-block| M-tile| N-tile | K-tile | M-warp| N-warp| K-warp| M-warp| N-warp| K-warp| Vector| Vector| Vector| - //#####################################| Dim| | | | Type| Type| Type| Elementwise| Elementwise| Elementwise| per| | | | | | | tile| tile| tile| size| size| size| - //#####################################| Spatial| | | | | | | Operation| Operation| Operation| CU| | | | | | | size| size| size| A| B| C| - //#####################################| | | | | | | | | | | - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, - // GroupedConvolutionBackwardWeightInvoker, +// clang-format off + //#####################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| In| Wei| Out| K-block| M-tile| N-tile | K-tile | M-warp| N-warp| K-warp| M-warp| N-warp| K-warp| Vector| Vector| Vector| Double| GEMM| + //#####################################| Dim| | | | Type| Type| Type| Elementwise| Elementwise| Elementwise| per| | | | | | | tile| tile| tile| size| size| size| smem| pipeline| + //#####################################| Spatial| | | | | | | Operation| Operation| Operation| CU| | | | | | | size| size| size| A| B| C| buffer| version| + GroupedConvolutionBackwardWeightInvoker, + GroupedConvolutionBackwardWeightInvoker + // GroupedConvolutionBackwardWeightInvoker, + // GroupedConvolutionBackwardWeightInvoker, + // GroupedConvolutionBackwardWeightInvoker, - //#####################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| In| Wei| Out| K-block| M-tile| N-tile | K-tile | M-warp| N-warp| K-warp| M-warp| N-warp| K-warp| Vector| Vector| Vector| - //#####################################| Dim| | | | Type| Type| Type| Elementwise| Elementwise| Elementwise| per| | | | | | | tile| tile| tile| size| size| size| - //#####################################| Spatial| | | | | | | Operation| Operation| Operation| CU| | | | | | | size| size| size| A| B| C| - GroupedConvolutionBackwardWeightInvoker, - GroupedConvolutionBackwardWeightInvoker, - GroupedConvolutionBackwardWeightInvoker, - GroupedConvolutionBackwardWeightInvoker, - GroupedConvolutionBackwardWeightInvoker, - - GroupedConvolutionBackwardWeightInvoker, - GroupedConvolutionBackwardWeightInvoker, - GroupedConvolutionBackwardWeightInvoker - // // clang-format on - // // clang-format on - // clang-format on - >; + // GroupedConvolutionBackwardWeightInvoker, + // GroupedConvolutionBackwardWeightInvoker, + // GroupedConvolutionBackwardWeightInvoker +// // clang-format on +>; } // namespace ops } // namespace ck_tile diff --git a/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_bwd_weight_fp16_instances.hpp b/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_bwd_weight_fp16_instances.hpp index 49c73eda95..e88b65e3e5 100644 --- a/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_bwd_weight_fp16_instances.hpp +++ b/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_bwd_weight_fp16_instances.hpp @@ -17,21 +17,14 @@ template using tile_grouped_conv_bwd_weight_f16_instances = std::tuple< - // clang-format off - //#####################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| In| Wei| Out| K-block| M-tile| N-tile | K-tile | M-warp| N-warp| K-warp| M-warp| N-warp| K-warp| Vector| Vector| Vector| - //#####################################| Dim| | | | Type| Type| Type| Elementwise| Elementwise| Elementwise| per| | | | | | | tile| tile| tile| size| size| size| - //#####################################| Spatial| | | | | | | Operation| Operation| Operation| CU| | | | | | | size| size| size| A| B| C| - //#####################################| | | | | | | | | | | - GroupedConvolutionBackwardWeightInvoker, - GroupedConvolutionBackwardWeightInvoker, - GroupedConvolutionBackwardWeightInvoker, - GroupedConvolutionBackwardWeightInvoker, - GroupedConvolutionBackwardWeightInvoker, - GroupedConvolutionBackwardWeightInvoker, - GroupedConvolutionBackwardWeightInvoker, - GroupedConvolutionBackwardWeightInvoker - // clang-format on - >; +// clang-format off + //#####################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| In| Wei| Out| K-block| M-tile| N-tile | K-tile | M-warp| N-warp| K-warp| M-warp| N-warp| K-warp| Vector| Vector| Vector| Double| GEMM| + //#####################################| Dim| | | | Type| Type| Type| Elementwise| Elementwise| Elementwise| per| | | | | | | tile| tile| tile| size| size| size| smem| pipeline| + //#####################################| Spatial| | | | | | | Operation| Operation| Operation| CU| | | | | | | size| size| size| A| B| C| buffer| version| | + GroupedConvolutionBackwardWeightInvoker, + GroupedConvolutionBackwardWeightInvoker +// clang-format on +>; } // namespace ops } // namespace ck_tile diff --git a/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_bwd_weight_invoker.hpp b/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_bwd_weight_invoker.hpp index 918ad2271b..ca32e9f0c7 100644 --- a/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_bwd_weight_invoker.hpp +++ b/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_bwd_weight_invoker.hpp @@ -12,6 +12,7 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/library/tensor_operation_instance/gpu/gemm_configs.hpp" namespace ck_tile { namespace ops { @@ -31,7 +32,7 @@ template + ck_tile::index_t VectorSizeC, + bool DoubleSmemBuffer, + ck_tile::index_t PipelineVersion> struct GroupedConvolutionBackwardWeightInvoker : public GroupedConvolutionBackwardWeightBaseInvoker { - using CodegenShape_ = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence, + GemmConfigBase::PermuteA, + GemmConfigBase::PermuteB>; - static constexpr auto ConvSpec_ = ck_tile::ConvolutionSpecialization::Default; + static constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner_ = ck_tile::GemmTile1DPartitioner; - using GroupedConvTraitsType_ = ck_tile::GroupedConvTraits, // = DsLayout - OutLayout, - VectorSizeA, - VectorSizeB, - VectorSizeC>; + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using GroupedConvTraitsType = ck_tile::GroupedConvTraits, // = DsLayout + OutLayout, + VectorSizeA, + VectorSizeB, + VectorSizeC>; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< + GemmConfigBase::kPadM, + GemmConfigBase::kPadN, + GemmConfigBase::kPadK, + DoubleSmemBuffer, + typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout, + typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout, + typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout, + GemmConfigBase::TransposeC, + GemmConfigBase::UseStructuredSparsity, + false, // Persistent, + GemmConfigBase::NumWaveGroups>; using AccDataType = float; - using CDEElementWise = ck_tile::element_wise::PassThrough; + using GemmPipelineProblem = ck_tile::GemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + true, + VectorSizeA, + VectorSizeB>; - using CodegenPipelineProblem_ = ck_tile::GemmPipelineProblem< - InDataType, - WeiDataType, - AccDataType, - CodegenShape_, - typename GroupedConvTraitsType_::GroupedConvImplicitGemmTraitsBwdWeight, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - InDataType, - true, - GroupedConvTraitsType_::VectorSizeA, - GroupedConvTraitsType_::VectorSizeB>; + using BaseGemmPipeline = typename PipelineTypeTraits::template UniversalGemmPipeline; + + template + auto CreateKernel() const + { + constexpr auto scheduler = GemmConfigBase::Scheduler; + + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; - using CodegenPipeline_ = ck_tile::GemmPipelineAGmemBGmemCRegV1; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline; - using ConvEpilogueAtomicAdd_ = ck_tile::CShuffleEpilogue, // = DsDataType, - AccDataType, - OutDataType, - typename GroupedConvTraitsType_::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, - CDEElementWise, - TilePartitioner_::MPerBlock, - TilePartitioner_::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - CodegenPipelineProblem_::TransposeC, - ck_tile::memory_operation_enum::atomic_add, - 1, - true, - GroupedConvTraitsType_::VectorSizeC>>; + using CDEElementWise = ck_tile::element_wise::PassThrough; - using ConvEpilogueSet_ = ck_tile::CShuffleEpilogue, // = DsDataType, - AccDataType, - OutDataType, - typename GroupedConvTraitsType_::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, - CDEElementWise, - TilePartitioner_::MPerBlock, - TilePartitioner_::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - CodegenPipelineProblem_::TransposeC, - ck_tile::memory_operation_enum::set, - 1, - true, - GroupedConvTraitsType_::VectorSizeC>>; + using ConvEpilogue = ck_tile::CShuffleEpilogue, // = DsDataType + AccDataType, + WeiDataType, + typename GroupedConvTraitsType::ImplicitGemmDsLayout, + ck_tile::tensor_layout::gemm::RowMajor, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + GemmConfigBase::TransposeC, + MemOp, + 1, + true, + GroupedConvTraitsType::VectorSizeC>>; - using KernelSplitK = ck_tile::GroupedConvolutionBackwardWeightKernel; - - using KernelNonSplitK = ck_tile::GroupedConvolutionBackwardWeightKernel; + return ck_tile::GroupedConvolutionBackwardWeightKernel{}; + } bool IsSupportedArgument(const ck_tile::GroupedConvBwdWeightHostArgs& args) const override { - if (args.k_batch == 1) + if (args.k_batch > 1) { - return KernelNonSplitK::IsSupportedArgument(KernelNonSplitK::MakeKernelArgs(args)); + using Kernel = decltype(CreateKernel()); + return Kernel::IsSupportedArgument(args); } - return KernelSplitK::IsSupportedArgument(KernelSplitK::MakeKernelArgs(args)); + using Kernel = decltype(CreateKernel()); + return Kernel::IsSupportedArgument(args); }; - template - float RunImpl(const ck_tile::GroupedConvBwdWeightHostArgs& args, bool time_kernel) + float Run(const ck_tile::GroupedConvBwdWeightHostArgs& args, bool time_kernel) const override { - auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + const ck_tile::index_t gemm_k = + args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), + args.output_spatial_lengths_.end(), + 1, + std::multiplies()); - constexpr int n_warmup = 5; - constexpr int n_repeat = 50; - ck_tile::stream_config s {nullptr, time_kernel, 1, n_warmup, n_repeat}; - float avg_time = ck_tile::launch_kernel_time_mask( - s, - Kernel::Preprocess(kargs, s), - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + float ave_time{0}; - return avg_time; - }; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto memory_operation = memory_operation_.value; - float Run(const ck_tile::GroupedConvBwdWeightHostArgs& args, bool time_kernel) override - { - if (args.k_batch == 1) - { - return RunImpl(args, time_kernel); - } - else - { - return RunImpl(args, time_kernel); - } + auto kernel = CreateKernel(); + using Kernel = decltype(kernel); + + auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args); + const dim3 blocks = Kernel::BlockSize(); + + constexpr int n_warmup = 5; + constexpr int n_repeat = 50; + ck_tile::stream_config s {nullptr, time_kernel, 1, n_warmup, n_repeat}; + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(kernel, grids, blocks, 0, kargs)); + + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, tail_number_, MemoryOpSet{}); + } + else + { + Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; }; std::string GetName(const ck_tile::GroupedConvBwdWeightHostArgs& args) const override { std::stringstream min_occupancy; min_occupancy << "_blk_per_cu_" << kBlockPerCu; - if (args.k_batch == 1) + if (args.k_batch > 1) { - return KernelNonSplitK::GetName() + min_occupancy.str(); + using Kernel = decltype(CreateKernel()); + return Kernel::GetName() + min_occupancy.str(); } - return KernelSplitK::GetName() + min_occupancy.str(); + using Kernel = decltype(CreateKernel()); + return Kernel::GetName() + min_occupancy.str(); }; GroupedConvolutionBackwardWeightInvoker() = default; diff --git a/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_instance_factory.hpp b/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_instance_factory.hpp index 6ec251b052..e4bbc9d8cc 100644 --- a/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_instance_factory.hpp +++ b/library/include/ck_tile/library/tensor_operation_instance/gpu/tile_grouped_conv_instance_factory.hpp @@ -10,6 +10,11 @@ #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_COMPUTE_V5 4 + namespace ck_tile { namespace ops {