diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index c508126adb..2b09ba0b1f 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -69,7 +69,7 @@ template concept FwdXdlV3Algorithm = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && SpecifiesGridwiseFwdXdlGemm && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesBlockGemm; + SpecifiesGemmSpecialization && SpecifiesBlockGemm && SpecifiesNumGroupsToMerge; // FWD WMMA algorithm concepts template diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 14266ad63f..7ea9938ea4 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -161,7 +161,8 @@ struct ConvFwdXdlV3Factory BLOCK_GEMM.pipeline_version, typename Types::InComputeType, typename Types::WeiComputeType, - IS_DIRECT_LOAD>; + IS_DIRECT_LOAD, + ALGORITHM.num_conv_groups_to_merge>; }; } // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 4549b76a3f..98a304f13a 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -71,7 +71,8 @@ template + bool DirectLoad, + index_t NumGroupsToMerge> struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; } // namespace ck::tensor_operation::device @@ -132,7 +133,8 @@ template + bool DirectLoad, + index_t NumGroupsToMerge> struct InstanceTraits> + DirectLoad, + NumGroupsToMerge>> { /// @brief Tag type identifying this device kernel variant using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag; @@ -270,6 +273,8 @@ struct InstanceTraits(); // 47. AComputeDataType oss << "," << detail::type_name(); // 48. BComputeDataType oss << "," << (DirectLoad ? "true" : "false"); // 49. DirectLoad + oss << "," << kNumGroupsToMerge; // 50. NumGroupsToMerge oss << ">"; return oss.str(); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp index 8d85370b26..501bc8d9d8 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp @@ -35,7 +35,8 @@ TEST(FwdConvInstances, .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v2_intrawave); + .with_block_gemm(BlockGemmDesc_v2_intrawave) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp index 610e2fad5f..d7c37ea3e3 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp @@ -31,7 +31,8 @@ TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v1_intrawave); + .with_block_gemm(BlockGemmDesc_v1_intrawave) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; @@ -69,7 +70,8 @@ TEST(FwdConvInstances, .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v5_intrawave); + .with_block_gemm(BlockGemmDesc_v5_intrawave) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index 51bc45c29b..135ad193db 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -32,7 +32,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xd .with_transfer(cku::Transfer_4x64x1) .with_fwd_specializations(ckb::ConvSpecialization::DEFAULT, ckb::GemmSpecialization::MNKPadding) - .with_block_gemm(cku::BlockGemmDesc_v3_intrawave); + .with_block_gemm(cku::BlockGemmDesc_v3_intrawave) + .with_num_conv_groups_to_merge(1); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp index 9e6ca00e58..7311020007 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp @@ -31,7 +31,8 @@ TEST(FwdConvInstances, .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v3_intrawave); + .with_block_gemm(BlockGemmDesc_v3_intrawave) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp index 56d4b8be59..8aba066d70 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp @@ -32,7 +32,8 @@ TEST(FwdConvInstances, .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v4_intrawave); + .with_block_gemm(BlockGemmDesc_v4_intrawave) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp index df8339241b..f5779bf5ae 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp @@ -32,7 +32,8 @@ TEST(FwdConvInstances, .with_transfer(Transfer_4x64x1) .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, GemmSpecialization::MNKPadding) - .with_block_gemm(BlockGemmDesc_v1_intrawave); + .with_block_gemm(BlockGemmDesc_v1_intrawave) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index 7de7fae92d..ba9cb0a030 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -1188,7 +1188,8 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer ck::half_t, // AComputeDataType ck::half_t, // BComputeDataType - false>; // DirectLoad + false, // DirectLoad + 1>; // NumGroupsToMerge // Use ConvTraitsTmpl to extract compile-time information const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); diff --git a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_features.cpp b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_features.cpp index 72269c38ac..aad6da7f6e 100644 --- a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_features.cpp +++ b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_features.cpp @@ -97,6 +97,7 @@ static constexpr int kCShuffleMXdlPerWavePerShuffle = 1; static constexpr int kCShuffleNXdlPerWavePerShuffle = 1; static constexpr int kCDEBlockTransferScalarPerVector_NPerBlock = 8; static constexpr bool kDirectLoad = false; +static constexpr int kNumGroupsToMerge = 1; using DefaultABlockTransferThreadClusterLengths = ck::Sequence<4, 64, 1>; using DefaultABlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 0, 2>; @@ -176,7 +177,8 @@ using DeviceInstanceForTests_V3 = BlkGemmPipelineVer, ADataType, BDataType, - defaults::kDirectLoad>; + defaults::kDirectLoad, + defaults::kNumGroupsToMerge>; // Test case helper for specialization testing template diff --git a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_instances.cpp b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_instances.cpp index 38942f9d45..fa4bc73bd2 100644 --- a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_instances.cpp +++ b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_instances.cpp @@ -102,7 +102,8 @@ TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffleV3) ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer ck::half_t, // AComputeDataType ck::half_t, // BComputeDataType - false>; // DirectLoad + false, // DirectLoad + 1>; // NumGroupsToMerge using InstTraits = ck_tile::reflect::InstanceTraits; const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index f5b9bdc3b5..c4e83293ef 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -566,7 +566,8 @@ using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = FwdXdlGemm_, Transfer_<>, ConvSpecializationFwd_, - BlockGemm_>; + BlockGemm_, + GemmBatchOptions_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = ConvAlgorithmTemplate); diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index ad0a2cadc6..f10a29722d 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -83,7 +83,8 @@ TEST(InstanceTraits, V3ExtractsAllFieldsCorrectly) ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer ck::half_t, // AComputeDataType ck::half_t, // BComputeDataType - false>; + false, // DirectLoad + 1>; // NumGroupsToMerge // Use InstanceTraits to extract compile-time information using Traits = ck_tile::reflect::InstanceTraits; @@ -225,7 +226,8 @@ TEST(InstanceTraits, V3InstanceStringReturnsCorrectFormat) ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer ck::half_t, // AComputeDataType ck::half_t, // BComputeDataType - false>; // DirectLoad + false, // DirectLoad + 1>; // NumGroupsToMerge std::string instance_str = ck_tile::reflect::instance_string(); @@ -278,7 +280,8 @@ TEST(InstanceTraits, V3InstanceStringReturnsCorrectFormat) ",v1" // BlkGemmPipelineVer ",fp16" // AComputeDataType ",fp16" // BComputeDataType - ",false>"; // DirectLoad + ",false" // DirectLoad + ",1>"; // NumGroupsToMerge EXPECT_EQ(instance_str, expected_str); } diff --git a/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp b/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp index ccfa4c7197..5a0dcbeaf5 100644 --- a/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp +++ b/experimental/builder/test/test_instance_string_fwd_grp_conv_v3.cpp @@ -77,7 +77,8 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3" ",v4" // BlkGemmPipelineVer ",fp16" // AComputeDataType ",fp16" // BComputeDataType - ",false>"; // DirectLoad + ",false" // DirectLoad + ",1>"; // NumGroupsToMerge // Test describe() through base class pointer for V3 variant TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvV3) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 15a5e08803..ae02f5ee77 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -382,7 +382,8 @@ template + bool DirectLoad = false, + index_t NumGroupsToMerge = 1> struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 : public DeviceGroupedConvFwdMultipleABD(); + static_assert(NumGroupsToMerge >= 1); + static constexpr bool isMultiA = is_detected::value; static constexpr bool isMultiB = is_detected::value; static constexpr bool isMultiD = DsDataType::Size() > 0; @@ -447,7 +450,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ConvForwardSpecialization, true /*SplitN*/, ADataType, - EDataType>; + EDataType, + NumGroupsToMerge>; using ComputePtrOffset = ComputePtrOffsetOfStridedBatch; @@ -784,8 +788,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 cde_element_op_{cde_element_op} { // A/B/E Batch/N Stride - compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides_[0]; - compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0]; + compute_ptr_offset_of_groups_.BatchStrideA_ = + a_g_n_c_wis_strides_[0] * NumGroupsToMerge; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0] * NumGroupsToMerge; compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_; // p_as and p_bs are pointers @@ -796,7 +801,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 static_for<0, NumDTensor, 1>{}([&](auto i) { using DLayout = remove_cvref_t>; // D batch stride - compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0]; + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides_[i][0] * NumGroupsToMerge; compute_ptr_offset_of_n_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_; @@ -816,7 +822,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); }); - compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = + e_g_n_k_wos_strides_[0] * NumGroupsToMerge; compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_; if constexpr(is_NGCHW_GKCYX_NGKHW() || @@ -999,7 +1006,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); - gdy = arg.num_group_; + gdy = arg.num_group_ / NumGroupsToMerge; gdz = num_workgroups_per_Conv_N; index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock; @@ -1499,6 +1506,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } } + if constexpr(NumGroupsToMerge > 1) + { + if(G % NumGroupsToMerge != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported! G % NumGroupsToMerge != 0: G=" << G + << ", NumGroupsToMerge=" << NumGroupsToMerge << std::endl; + } + return false; + } + } + if(get_device_name() == "gfx908") { // FIXME: re-enable fp64 when SWDEV-335738 is fixed @@ -1595,6 +1615,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } } } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3) + { + if(C != 1) + { + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3]; + + if(filter_spatial_dim != I3) + { + return false; + } + } + } // check vector access of A // FIXME: layout @@ -2106,6 +2142,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if constexpr(DirectLoad) { str << "_DirectLoad"; } + if constexpr (NumGroupsToMerge > 1) { + str << "_MergedGroups"; + } str << "<" << BlockSize << ", " @@ -2125,8 +2164,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 << "BlkGemmPipelineScheduler: " << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " << "BlkGemmPipelineVersion: " - << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] - << ">"; + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer]; + if constexpr (NumGroupsToMerge > 1) { + str << ", " << NumGroupsToMerge; + } + str << ">"; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index ab2821d989..4b91382d10 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -1380,11 +1380,11 @@ struct TransformConvFwdToGemm else { const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( - make_tuple(K_, NumGroupsToMerge, ZYX_ * C_), - make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_)); + make_tuple(NumGroupsToMerge, K_, ZYX_ * C_), + make_tuple(GStrideTensorB_, KStrideTensorB_, CStrideTensorB_)); return transform_tensor_descriptor( wei_gemmn_groups_gemmk_desc, - make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)), + make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)), make_pass_through_transform(ZYX_ * C_)), make_tuple(Sequence<0, 1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -1550,20 +1550,20 @@ struct TransformConvFwdToGemm else { const auto nhwo_groups_k_1_desc = - make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, K_, 1), + make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, 1, K_), make_tuple(NStrideTensorC_, HoStride_, WoStride_, GStrideTensorC_, - KStrideTensorC_, - GStrideTensorC_)); + GStrideTensorC_, + KStrideTensorC_)); // Padd 1 to NumGroupsToMerge const auto padded_desc = transform_tensor_descriptor( nhwo_groups_k_1_desc, make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(K_), - make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_pad_transform(1, 0, NumGroupsToMerge - 1), + make_pass_through_transform(K_)), make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); // We need only matrices from diagonal. X_or returns 0 for the same @@ -1577,13 +1577,13 @@ struct TransformConvFwdToGemm make_tuple(make_pass_through_transform(NDoHoWo), make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), make_pass_through_transform(K_)), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); // Merge To M, N return transform_tensor_descriptor( unmerged_padded_desc, make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), - make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_merge_transform(make_tuple(NumGroupsToMerge, K_))), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp index 3b7ce0df3a..4ef55179a0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp @@ -4,6 +4,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -54,7 +55,13 @@ using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple< // Instances with NumGroupsPerBatch > 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32>, + + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block GEMM| Block GEMM| In| Wei| Direct| Num| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| pipeline| pipeline| compute| compute| load| merged| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| scheduler| version| type| type| | groups| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, BF16, BF16, false, 8> // clang-format on >; @@ -75,7 +82,16 @@ using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x = std::tuple< //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, LoopScheduler::Default, 8>, + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block GEMM| Block GEMM| In| Wei| Direct| Num| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| pipeline| pipeline| compute| compute| load| merged| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| scheduler| version| type| type| | groups| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, BF16, BF16, false, 8> // clang-format on >; @@ -96,7 +112,13 @@ using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple< // Instances with NumGroupsPerBatch > 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32>, + + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block GEMM| Block GEMM| In| Wei| Direct| Num| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| pipeline| pipeline| compute| compute| load| merged| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| scheduler| version| type| type| | groups| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16, F16, false, 8> // clang-format on >; @@ -120,9 +142,15 @@ using device_grouped_conv_fwd_xdl_merged_groups_f16_instances_2x = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, LoopScheduler::Default, 8>, + + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block GEMM| Block GEMM| In| Wei| Direct| Num| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| pipeline| pipeline| compute| compute| load| merged| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| scheduler| version| type| type| | groups| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16, F16, false, 8> // clang-format on >;