diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp index 4950e002ac..6b37575ef1 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp @@ -21,8 +21,6 @@ template , typename DsLayout = ck_tile::tuple<>, typename CDEElementWise = ck_tile::element_wise::PassThrough> @@ -66,9 +64,7 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, WeiLayout, DsLayout, OutLayout, - NumGroupsToMerge, - MPerGroup, - NPerGroup>; + NumGroupsToMerge>; using CodegenPipelineProblem = ck_tile::GemmPipelineProblem int run_grouped_conv_bwd_weight_example_prec_type( @@ -185,9 +179,7 @@ int run_grouped_conv_bwd_weight_example_prec_type( InPrecType, WeiPrecType, OutPrecType, - NumGroupsToMerge, - MPerGroup, - NPerGroup>( + NumGroupsToMerge>( argc, argv, NWGC{}, GKXC{}, NWGK{}); } else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") @@ -196,9 +188,7 @@ int run_grouped_conv_bwd_weight_example_prec_type( InPrecType, WeiPrecType, OutPrecType, - NumGroupsToMerge, - MPerGroup, - NPerGroup>( + NumGroupsToMerge>( argc, argv, NHWGC{}, GKYXC{}, NHWGK{}); } else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK") @@ -207,9 +197,7 @@ int run_grouped_conv_bwd_weight_example_prec_type( InPrecType, WeiPrecType, OutPrecType, - NumGroupsToMerge, - MPerGroup, - NPerGroup>( + NumGroupsToMerge>( argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{}); } else @@ -223,8 +211,6 @@ int run(const std::string& in_layout, const std::string& wei_layout, const std::string& out_layout, int num_groups_to_merge, - int m_per_group, - int n_per_group, int argc, char* argv[]) { @@ -234,47 +220,7 @@ int run(const std::string& in_layout, } else if (num_groups_to_merge == 8) { - if (m_per_group == 1 && n_per_group == 16) - { - return run_grouped_conv_bwd_weight_example_prec_type(in_layout, wei_layout, out_layout, argc, argv); - } - else if (m_per_group == 2 && n_per_group == 16) - { - return run_grouped_conv_bwd_weight_example_prec_type(in_layout, wei_layout, out_layout, argc, argv); - } - else if (m_per_group == 4 && n_per_group == 16) - { - return run_grouped_conv_bwd_weight_example_prec_type(in_layout, wei_layout, out_layout, argc, argv); - } - else if (m_per_group == 1 && n_per_group == 4) - { - return run_grouped_conv_bwd_weight_example_prec_type(in_layout, wei_layout, out_layout, argc, argv); - } - else if (m_per_group == 2 && n_per_group == 4) - { - return run_grouped_conv_bwd_weight_example_prec_type(in_layout, wei_layout, out_layout, argc, argv); - } - else if (m_per_group == 4 && n_per_group == 4) - { - return run_grouped_conv_bwd_weight_example_prec_type(in_layout, wei_layout, out_layout, argc, argv); - } - else if (m_per_group == 1 && n_per_group == 8) - { - return run_grouped_conv_bwd_weight_example_prec_type(in_layout, wei_layout, out_layout, argc, argv); - } - else if (m_per_group == 2 && n_per_group == 8) - { - return run_grouped_conv_bwd_weight_example_prec_type(in_layout, wei_layout, out_layout, argc, argv); - } - else if (m_per_group == 4 && n_per_group == 8) - { - return run_grouped_conv_bwd_weight_example_prec_type(in_layout, wei_layout, out_layout, argc, argv); - } - else - { - throw std::runtime_error("Unsupported MPerGroup and NPerGroup combination for NumGroupsToMerge=8! Supported combinations are (1,16), (2,8), (4,4), (8,2), (16,1)."); - } - + return run_grouped_conv_bwd_weight_example_prec_type(in_layout, wei_layout, out_layout, argc, argv); } else { @@ -293,18 +239,16 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[]) std::string wei_layout = arg_parser.get_str("wei_layout"); std::string out_layout = arg_parser.get_str("out_layout"); ck_tile::index_t num_groups_to_merge = arg_parser.get_int("num_groups_to_merge"); - ck_tile::index_t m_per_group = arg_parser.get_int("m_per_group"); - ck_tile::index_t n_per_group = arg_parser.get_int("n_per_group"); if(data_type == "fp16") { return run( - in_layout, wei_layout, out_layout, num_groups_to_merge, m_per_group, n_per_group, argc, argv); + in_layout, wei_layout, out_layout, num_groups_to_merge, argc, argv); } else if(data_type == "bf16") { return run( - in_layout, wei_layout, out_layout, num_groups_to_merge, m_per_group, n_per_group, argc, argv); + in_layout, wei_layout, out_layout, num_groups_to_merge, argc, argv); } else { diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp index 522c60c1df..0de0c92d62 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp @@ -122,9 +122,7 @@ auto create_args(int argc, char* argv[]) .insert("split_k", "1", "splitK value") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") .insert("json", "0", "0: No Json, 1: Dump Results in Json format") - .insert("num_groups_to_merge", "1", "Number of groups to merge for grouped convolution") - .insert("m_per_group", "0", "Number of elements per grouped block in M-dimension") - .insert("n_per_group", "0", "Number of elements per grouped block in N-dimension"); + .insert("num_groups_to_merge", "1", "Number of groups to merge for grouped convolution"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc index 829c811b52..e09ed421ff 100644 --- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc @@ -10,9 +10,7 @@ template + ck_tile::index_t NumGroupsToMerge = 1> float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args, int n_warmup, int n_repeat) @@ -25,9 +23,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args InLayout, WeiLayout, OutLayout, - NumGroupsToMerge, - MPerGroup, - NPerGroup>( + NumGroupsToMerge>( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = args.GetFlops(); @@ -46,8 +42,6 @@ template @@ -153,9 +147,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts( InLayout, WeiLayout, OutLayout, - NumGroupsToMerge, - MPerGroup, - NPerGroup>(args, n_warmup, n_repeat); + NumGroupsToMerge>(args, n_warmup, n_repeat); weight_dev_buf.FromDevice(weight.data()); bool pass = true; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index c7a86ad7d4..9f42525556 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -219,13 +219,6 @@ struct CShuffleEpilogue template CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor() { - if constexpr(NumGroupsToMerge > 1) - { - // We haven't yet tested the ColumnMajor case. - static_assert(std::is_same_v, - "Currently, the CShuffle Epilogue with NumGroupsToMerge > 1 only supports the Row Major layout"); - } - // N is contiguous dimension if constexpr(std::is_same_v) { @@ -264,33 +257,17 @@ struct CShuffleEpilogue CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - if constexpr(NumGroupsToMerge > 1) - { - return kMPerBlock * kNPerBlock * sizeof(ODataType); - } - else - { - return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType); - } + return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType); } - template - //template + template CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, const DsDramWindows& ds_dram_windows, void* p_smem) { - if constexpr (NumGroupsToMerge == 1) - { - return unmerged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem); - } - else - { - return merged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem); - //return merged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem); - } + return unmerged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem); } template diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index e0faf89ad0..b0d5a1f162 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -620,45 +620,9 @@ struct GroupedConvolutionBackwardWeightKernel return false; } - if (GroupedConvTraitsType_::NumGroupsToMerge > 1) - { - const index_t ZYX = kargs.ZYX; - const index_t MPerGroup = ConvK; - const index_t NPerGroup = ZYX * ConvC; - - // TODO: Fix this check - if (GroupedConvTraitsType_::MPerGroup != MPerGroup) - { - CK_TILE_ERROR("MPerGroup must be equal to Conv K!"); - return false; - } - - // TODO: Fix this check - if (GroupedConvTraitsType_::NPerGroup != NPerGroup) - { - CK_TILE_ERROR("NPerGroup must be equal to Conv C * ZYX!"); - return false; - } - - // TODO: Remove this check when ConvC > 1 is implemented. - if (ConvC > 1) - { - CK_TILE_ERROR("Only Conv C == 1 is supported!"); - return false; - } - - if (kargs.NumGroupsPerBatch * ConvC * ZYX > TilePartitioner_::NPerBlock) - { - CK_TILE_ERROR("NumGroupsPerBatch * Conv C * ZYX must be less or equal to NPerBlock!"); - return false; - } - - if (kargs.NumGroupsPerBatch * ConvK > TilePartitioner_::MPerBlock) - { - CK_TILE_ERROR("NumGroupsToMerge * Conv K must be less or equal to MPerBlock!"); - return false; - } - } + // TODO: Should we enforce + // - ConvG % NumGroupsToMerge == 0? + // - ConvK % NumGroupsToMerge == 0? return true; } @@ -845,8 +809,6 @@ struct GroupedConvolutionBackwardWeightKernel auto& c_block_window = gemm_tile_windows.at(I3); EpiloguePipeline{}.template operator()< - GroupedConvTraitsType_::MPerGroup, - GroupedConvTraitsType_::NPerGroup, decltype(c_block_window), decltype(c_block_tile)>( c_block_window, c_block_tile, d_block_window, smem_ptr_0); diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 699957822a..c7de32dab5 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -50,9 +50,7 @@ template + index_t NumGroupsToMerge_ = 1> struct GroupedConvTraits { private: @@ -64,8 +62,6 @@ struct GroupedConvTraits public: static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_; - static constexpr index_t MPerGroup = MPerGroup_; - static constexpr index_t NPerGroup = NPerGroup_; static constexpr index_t NDimSpatial = NDimSpatial_; static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_; using InLayout = InLayout_;