diff --git a/example/ck_tile/20_grouped_convolution/conv_configs.hpp b/example/ck_tile/20_grouped_convolution/conv_configs.hpp index b9edf247cc..8a2a60a197 100644 --- a/example/ck_tile/20_grouped_convolution/conv_configs.hpp +++ b/example/ck_tile/20_grouped_convolution/conv_configs.hpp @@ -14,19 +14,11 @@ struct ConvConfigBase { - static constexpr bool kPadM = true; - static constexpr bool kPadN = true; - static constexpr bool kPadK = true; - - static constexpr bool TransposeC = false; - static constexpr ck_tile::index_t VectorSizeA = 4; static constexpr ck_tile::index_t VectorSizeB = 8; static constexpr ck_tile::index_t VectorSizeC = 8; - static constexpr int kBlockPerCu = 1; - static constexpr ck_tile::index_t TileParitionerGroupNum = 8; - static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; @@ -210,9 +202,9 @@ struct ConvConfigComputeV5 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; - static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; + static constexpr ck_tile::index_t NumWaveGroups = 2; }; template diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp index 14a533ffc9..d19d3ac8ec 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp @@ -22,8 +22,6 @@ struct GroupedConvolutionBackwardDataInvoker static float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args, const ck_tile::stream_config& s) { - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -32,36 +30,33 @@ struct GroupedConvolutionBackwardDataInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + ConvConfig::VectorSizeA, + ConvConfig::VectorSizeB, + ConvConfig::VectorSizeC>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutBwdData, + typename GroupedConvTraitsType::BsLayoutBwdData, + typename GroupedConvTraitsType::CLayoutBwdData, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardDataInvoker WeiDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdData< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, InDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -93,95 +89,96 @@ struct GroupedConvolutionBackwardDataInvoker const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ConvConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ConvConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + InDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - auto preprocess = [&]() { - ck_tile::hip_check_error(hipMemsetAsync( - kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); - }; - - ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; + auto preprocess = [&]() { + ck_tile::hip_check_error(hipMemsetAsync( + kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); }; + ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp index 0e777c5f8a..81b9d402ce 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp @@ -21,8 +21,6 @@ struct GroupedConvolutionBackwardWeightInvoker static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, const ck_tile::stream_config& s) { - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -31,37 +29,34 @@ struct GroupedConvolutionBackwardWeightInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA; - constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB; - constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutBwdWeight, + typename GroupedConvTraitsType::BsLayoutBwdWeight, + typename GroupedConvTraitsType::CLayoutBwdWeight, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardWeightInvoker InDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, WeiDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -101,21 +97,21 @@ struct GroupedConvolutionBackwardWeightInvoker constexpr auto scheduler = ConvConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -127,7 +123,7 @@ struct GroupedConvolutionBackwardWeightInvoker AccDataType, WeiDataType, typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDEElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -136,10 +132,10 @@ struct GroupedConvolutionBackwardWeightInvoker ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, - ConvConfig::TransposeC, + GroupedConvTraitsType::FixedGemmParams::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp index a8e41438c8..8cef2bde65 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp @@ -23,8 +23,6 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker { using WorkspaceDataType = float; - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -33,36 +31,34 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 4; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + ConvConfig::VectorSizeA, + ConvConfig::VectorSizeB, + ConvConfig::VectorSizeC, + ConvConfig::NumGroupsToMerge>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutBwdWeight, + typename GroupedConvTraitsType::BsLayoutBwdWeight, + typename GroupedConvTraitsType::CLayoutBwdWeight, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -70,13 +66,14 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker InDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, WeiDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -102,21 +99,21 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker constexpr auto scheduler = ConvConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -128,7 +125,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker AccDataType, WorkspaceDataType, // C: Workspace normally Out typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDEElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -139,8 +136,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker ConvConfig::K_Warp_Tile, GemmPipelineProblem::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel(Kernel{}, grids, blocks, 0, kargs), - ck_tile::make_kernel(ElementwiseKernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(shape[1], 1), // Input Stride - ck_tile::make_tuple(shape[1], 1), // Output Stride - input_tensors, - static_cast(c_ptr))); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), + ck_tile::make_kernel( + ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(shape[1], 1), // Input Stride + ck_tile::make_tuple(shape[1], 1), // Output Stride + input_tensors, + static_cast(c_ptr))); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp index 2290f60d1f..7c8269d13c 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -32,8 +32,6 @@ struct GroupedConvolutionForwardInvoker { std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n"; } - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -42,38 +40,34 @@ struct GroupedConvolutionForwardInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - constexpr ck_tile::index_t NumGroupsToMerge = 1; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + ConvConfig::VectorSizeA, + ConvConfig::VectorSizeB, + ConvConfig::VectorSizeC, + ConvConfig::NumGroupsToMerge>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutFwd, + typename GroupedConvTraitsType::BsLayoutFwd, + typename GroupedConvTraitsType::CLayoutFwd, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -81,13 +75,14 @@ struct GroupedConvolutionForwardInvoker WeiDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsFwd< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, OutDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -116,21 +111,21 @@ struct GroupedConvolutionForwardInvoker constexpr auto scheduler = ConvConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + OutDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -142,7 +137,7 @@ struct GroupedConvolutionForwardInvoker AccDataType, OutDataType, typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -151,10 +146,10 @@ struct GroupedConvolutionForwardInvoker ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, - ConvConfig::TransposeC, + GroupedConvTraitsType::FixedGemmParams::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; using Kernel = ck_tile::GroupedConvolutionForwardKernel(Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp index 4d983baac5..9d2752727c 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp @@ -25,7 +25,6 @@ struct GroupedConvolutionForwardInvoker { std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n"; } - constexpr int kBlockPerCu = 1; // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< @@ -35,27 +34,18 @@ struct GroupedConvolutionForwardInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; - - using GroupedConvTraitsTypeDefault = ck_tile::GroupedConvTraits; + using GroupedConvTraitsTypeDefault = + ck_tile::GroupedConvTraits; using GroupedConvTraitsTypeLargeTensor = ck_tile::GroupedConvTraits; + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerM01>; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsTypeDefault::FixedGemmParams::kPadM, + GroupedConvTraitsTypeDefault::FixedGemmParams::kPadN, + GroupedConvTraitsTypeDefault::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::AsLayout, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::BsLayout, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsTypeDefault::AsLayoutFwd, + typename GroupedConvTraitsTypeDefault::BsLayoutFwd, + typename GroupedConvTraitsTypeDefault::CLayoutFwd, + GroupedConvTraitsTypeDefault::FixedGemmParams::TransposeC, + GroupedConvTraitsTypeDefault::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsTypeDefault::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -88,13 +83,14 @@ struct GroupedConvolutionForwardInvoker WeiDataType, AccDataType, GemmShape, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd, + typename GroupedConvTraitsTypeDefault::template GroupedConvImplicitGemmTraitsFwd< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, OutDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsTypeDefault::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsTypeDefault::VectorSizeA, + GroupedConvTraitsTypeDefault::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -116,9 +112,9 @@ struct GroupedConvolutionForwardInvoker using TransformType = ck_tile::TransformConvFwdToGemm; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + OutDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -290,7 +286,7 @@ struct GroupedConvolutionForwardInvoker AccDataType, OutDataType, typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDEElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -299,10 +295,10 @@ struct GroupedConvolutionForwardInvoker ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, - ConvConfig::TransposeC, + GroupedConvTraitsType::FixedGemmParams::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; // Use split-image kernel if layout supports it, otherwise use regular kernel @@ -368,7 +364,8 @@ struct GroupedConvolutionForwardInvoker } ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 9b5a60ee1f..8ea6cffa7d 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -74,6 +74,21 @@ struct GroupedConvTraits } public: + // Fixed values for Implicit GEMM + struct FixedGemmParams + { + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; + static constexpr bool kPadM = true; + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; + static constexpr bool TransposeC = false; + static constexpr bool FixedVectorSize = true; + static constexpr bool UseStructuredSparsity = false; + static constexpr bool Persistent = false; + using ELayout = ck_tile::tensor_layout::gemm::RowMajor; + }; + // Compile time parameters static constexpr bool EnableSplitImage = EnableSplitImage_; static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_; static constexpr index_t NDimSpatial = NDimSpatial_; @@ -82,31 +97,43 @@ struct GroupedConvTraits using WeiLayout = WeiLayout_; using DsLayout = DsLayout_; using OutLayout = OutLayout_; + + // Forward Gemm Layouts + using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor; + using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor; + using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor; + // Backward Data Gemm Layouts + using AsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor; + using BsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor; + using CLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor; + // Backward Weight Gemm Layouts + using AsLayoutBwdWeight = ck_tile::tensor_layout::gemm::ColumnMajor; + using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; + using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; + + template using GroupedConvImplicitGemmTraitsFwd = - TileGemmTraits; - using GroupedConvImplicitGemmTraitsBwdData = - TileGemmTraits; - using GroupedConvImplicitGemmTraitsBwdWeight = - TileGemmTraits; + TileGemmTraits; + template + using GroupedConvImplicitGemmTraitsBwdData = TileGemmTraits; + template + using GroupedConvImplicitGemmTraitsBwdWeight = TileGemmTraits; static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_; static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_; static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_; - static constexpr index_t NumDTensor = DsLayout::size(); + static constexpr ck_tile::index_t NumDTensor = DsLayout::size(); using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout()); };