From 6ccfb817e447992d8254055572de693319b2e83d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 26 Mar 2025 21:13:38 +0100 Subject: [PATCH] Add support for GKCYX grouped conv fwd (#2015) * Add support for GKCYX grouped conv fwd * fixes * fix * changelog * Fixes [ROCm/composable_kernel commit: 54c81a1fcf75720b8993cac156d849c2ee17a057] --- CHANGELOG.md | 1 + ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 8 +- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 11 +- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 11 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 245 ++++++++++++----- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 216 ++++++++++----- .../device/impl/device_grouped_conv_utils.hpp | 33 ++- .../gpu/grid/gridwise_elementwise_2d.hpp | 47 ++-- .../transform_conv_ngchw_to_nhwgc.hpp | 259 ++++++++++++++++-- .../device_operation_instance_factory.hpp | 6 +- .../gpu/grouped_convolution_forward.hpp | 75 +++-- .../grouped_convolution_forward_comp_xdl.inc | 30 +- ...uped_convolution_forward_mem_inter_xdl.inc | 30 +- ...uped_convolution_forward_mem_intra_xdl.inc | 30 +- .../gpu/grouped_convolution_forward_xdl.inc | 49 ++++ ..._convolution_forward_xdl_merged_groups.inc | 16 +- .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 40 +-- ..._ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp} | 12 +- ...l_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp} | 12 +- ...l_ngchw_gkcyx_ngkhw_f32_comp_instance.cpp} | 8 +- ...l_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp | 64 ----- ...wd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp | 38 +++ ...fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp | 38 +++ ...fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp | 38 +++ ...w_gkcyx_ngkhw_bf16_mem_inter_instance.cpp} | 8 +- ...w_gkcyx_ngkhw_bf16_mem_intra_instance.cpp} | 8 +- ...hw_gkcyx_ngkhw_f16_mem_inter_instance.cpp} | 8 +- ...hw_gkcyx_ngkhw_f16_mem_intra_instance.cpp} | 8 +- ...hw_gkcyx_ngkhw_f32_mem_inter_instance.cpp} | 8 +- ...hw_gkcyx_ngkhw_f32_mem_intra_instance.cpp} | 8 +- ...hw_gkyxc_ngkhw_int8_mem_inter_instance.cpp | 39 --- ...hw_gkyxc_ngkhw_int8_mem_intra_instance.cpp | 39 --- ...roups_ngchw_gkcyx_ngkhw_bf16_instance.cpp} | 10 +- ...groups_ngchw_gkcyx_ngkhw_f16_instance.cpp} | 10 +- ...groups_ngchw_gkcyx_ngkhw_f32_instance.cpp} | 10 +- ...groups_ngchw_gkyxc_ngkhw_int8_instance.cpp | 48 ---- profiler/src/profile_grouped_conv_fwd.cpp | 36 ++- script/convert_miopen_driver_to_profiler.py | 11 +- .../test_grouped_convnd_fwd.cpp | 7 +- 39 files changed, 1005 insertions(+), 570 deletions(-) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/{device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp} (92%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/{device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp} (92%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/{device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instance.cpp} (89%) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/{device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instance.cpp} (92%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/{device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instance.cpp} (92%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/{device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_inter_instance.cpp} (92%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/{device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_intra_instance.cpp} (92%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/{device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instance.cpp} (92%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/{device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instance.cpp} (92%) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/{device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instance.cpp => device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/{device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instance.cpp => device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/{device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instance.cpp => device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instance.cpp} (93%) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instance.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index d7b1389dcb..0d07abfc24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data +* Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). ### Optimized diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index d657c4447e..38e9e3c3d5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -496,11 +496,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 input_right_pads_{input_right_pads} { std::array a_g_n_k_wos_strides_transposed = - conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths, - a_g_n_k_wos_strides); + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, + a_g_n_k_wos_strides); std::array e_g_n_c_wis_strides_transposed = - conv_ngchw_to_nhwgc_transformer.TransposeStrides(e_g_n_c_wis_lengths, - e_g_n_c_wis_strides); + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(e_g_n_c_wis_lengths, + e_g_n_c_wis_strides); // populate Ds pointer static_for<0, NumDTensor, 1>{}([&](auto i) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 86e7927f71..033b84aafc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -534,11 +534,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle begin(output_spatial_lengths_)); std::array b_g_n_c_wis_strides_transposed = - conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths, - b_g_n_c_wis_strides); + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths, + b_g_n_c_wis_strides); std::array a_g_n_k_wos_strides_transposed = - conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths, - a_g_n_k_wos_strides); + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, + a_g_n_k_wos_strides); const auto descs = conv_to_gemm_transformer_v2 @@ -1425,11 +1425,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle // Different data type for A and B is not supported auto kernel_transpose = kernel_elementwise_dual, ck::Tuple, ck::Tuple, ck::Tuple, ck::Tuple, + ck::Tuple, + ck::Tuple, ck::Tuple, Block2TileMapElementwise, Block2TileMapElementwise, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index e98f60a245..6d2a354ce3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -453,11 +453,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle begin(output_spatial_lengths_)); std::array b_g_n_c_wis_strides_transposed = - conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths, - b_g_n_c_wis_strides); + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths, + b_g_n_c_wis_strides); std::array a_g_n_k_wos_strides_transposed = - conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths, - a_g_n_k_wos_strides); + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, + a_g_n_k_wos_strides); const auto descs = conv_to_gemm_transformer @@ -641,11 +641,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle // Different data type for A and B is not supported auto kernel_transpose = kernel_elementwise_dual, ck::Tuple, ck::Tuple, ck::Tuple, ck::Tuple, + ck::Tuple, + ck::Tuple, ck::Tuple, Block2TileMapElementwise, Block2TileMapElementwise, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 567ac7f3c9..69913163f0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -314,8 +314,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle static constexpr bool isMultiB = is_detected::value; // NGCHW is not supported for multiAB - static_assert(!(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) || + static_assert(!(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) || !(isMultiA || isMultiB)); static constexpr index_t NumATensor = GetNumABTensors(); @@ -355,11 +355,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_GKYXC_NGKHW(), + is_NGCHW_NGKHW(), ctc::NHWGC, - std::conditional_t(), - ctc::NDHWGC, - ALay>>; + std::conditional_t(), ctc::NDHWGC, ALay>>; const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(); @@ -373,8 +371,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle template static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { + namespace ctc = tensor_layout::convolution; + using Layout = std::conditional_t< + is_NGCHW_NGKHW(), + ctc::GKYXC, + std::conditional_t(), ctc::GKZYXC, BLay>>; + const auto wei_gemmnraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); @@ -387,11 +391,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_GKYXC_NGKHW(), + is_NGCHW_NGKHW(), ctc::NHWGK, - std::conditional_t(), - ctc::NDHWGK, - ELay>>; + std::conditional_t(), ctc::NDHWGK, ELay>>; const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); @@ -491,6 +493,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle remove_cvref_t({}, {}))>; + using GKCYXTransposeDescType = + remove_cvref_t({}, {}))>; + using GKYXCTransposeDescType = + remove_cvref_t({}, {}))>; + static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock; using GridwiseElementwiseInputTranspose = @@ -511,6 +520,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle I1, I0>; + using GridwiseElementwiseWeightTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + NPerBlock, + NPerBlock / ClusterLengthNPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence<1>, + Sequence, + I0, + I1>; + using GridwiseElementwiseOutputTranspose = GridwiseElementwise, Tuple, @@ -558,14 +585,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, - a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides( + a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( a_g_n_c_wis_lengths, a_g_n_c_wis_strides)}, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, - b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, - e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides( + e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( e_g_n_k_wos_lengths, e_g_n_k_wos_strides)}, conv_filter_strides_{conv_filter_strides}, conv_filter_dilations_{conv_filter_dilations}, @@ -744,8 +772,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } } - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { // Use not modified base strides a_in_transpose_desc_ = @@ -755,6 +783,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( a_g_n_c_wis_lengths, a_g_n_c_wis_strides); + b_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + b_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + e_in_transpose_desc_ = conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( e_g_n_k_wos_lengths, e_g_n_k_wos_strides); @@ -764,6 +799,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{ e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; } @@ -771,25 +808,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle std::size_t GetWorkspaceATensorSizeBytes() const { - const long_index_t a_acum = ck::accumulate_n( - a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); - return sizeof(ADataType) * a_acum; - } - - std::size_t GetWorkspaceETensorSizeBytes() const - { - const long_index_t e_accum = ck::accumulate_n( - e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); - return sizeof(EDataType) * e_accum; - } - - std::size_t GetWorkspaceSizeBytes() const - { - // Transpose require workspace for A and B - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { - return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes(); + const long_index_t a_acum = ck::accumulate_n( + a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128; } else { @@ -797,6 +822,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } } + std::size_t GetWorkspaceBTensorSizeBytes() const + { + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const long_index_t b_acum = ck::accumulate_n( + b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + const long_index_t e_accum = ck::accumulate_n( + e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + return sizeof(EDataType) * e_accum; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + + GetWorkspaceETensorSizeBytes(); + } + void Print() const { std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; @@ -849,10 +911,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // block-to-e-tile map Block2ETileMap block_2_etile_map_; Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_, - elementwise_block_2_ctile_map_transpose_e_; + elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_; NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_; NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_; + GKCYXTransposeDescType b_in_transpose_desc_; + GKYXCTransposeDescType b_out_transpose_desc_; // for computing batch offset ComputePtrOffsetOfStridedBatch @@ -942,14 +1006,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle else { const ADataType* p_a_grid = arg.p_as_grid_.At(I0); + const BDataType* p_b_grid = arg.p_bs_grid_.At(I0); EDataType* p_e_grid = arg.p_e_grid_; - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + p_e_grid = type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } + else if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) { p_a_grid = type_convert(arg.p_workspace_); p_e_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType); + (arg.GetWorkspaceATensorSizeBytes() + + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); } const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle< @@ -978,8 +1056,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - p_a_grid, // Pass just A descriptor instead of tuple - arg.p_bs_grid_.At(I0), // Pass just B descriptor instead of tuple + p_a_grid, + p_b_grid, arg.p_ds_grid_, p_e_grid, arg.a_element_op_, @@ -1009,50 +1087,71 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { float avg_time = 0.f; - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { - const index_t grid_size = + const index_t a_grid_size = arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( arg.a_in_transpose_desc_); + const index_t b_grid_size = + (is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + ? arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_) + : 0; // Dont run transpose B if not needed ADataType* p_a_out_grid = type_convert(arg.p_workspace_); + BDataType* p_b_out_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); - auto kernel_transpose = kernel_elementwise, - ck::Tuple, - ck::Tuple, - ck::Tuple, - Block2TileMapElementwise, - element_wise::PassThrough>; + auto kernel_transpose = kernel_elementwise_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + Block2TileMapElementwise, + element_wise::PassThrough>; avg_time += launch_and_time_kernel(stream_config, kernel_transpose, - dim3(grid_size), + dim3(a_grid_size + b_grid_size), dim3(ElementwiseBlocksize), 0, make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), make_tuple(arg.p_as_grid_.At(I0)), + make_tuple(arg.p_bs_grid_.At(I0)), make_tuple(p_a_out_grid), + make_tuple(p_b_out_grid), arg.elementwise_block_2_ctile_map_transpose_a_, - element_wise::PassThrough{}); + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + a_grid_size); } avg_time += RunGemm(arg, stream_config); - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { const index_t grid_size = arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( arg.e_in_transpose_desc_); - const EDataType* p_e_out_grid = + const EDataType* p_e_in_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType); + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); - EDataType* p_e_in_grid = arg.p_e_grid_; + EDataType* p_e_out_grid = arg.p_e_grid_; auto kernel_transpose = kernel_elementwise, @@ -1069,8 +1168,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle 0, make_tuple(arg.e_in_transpose_desc_), make_tuple(arg.e_out_transpose_desc_), - make_tuple(p_e_out_grid), make_tuple(p_e_in_grid), + make_tuple(p_e_out_grid), arg.elementwise_block_2_ctile_map_transpose_e_, element_wise::PassThrough{}); } @@ -1114,12 +1213,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // check if it's 1x1, stride=1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t ConvStride = arg.conv_filter_strides_[i]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; - if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) { return false; } @@ -1131,11 +1230,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // check if it's 1x1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; - const index_t LeftPad = arg.input_left_pads_[i]; - const index_t RightPad = arg.input_right_pads_[i]; + const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; - if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0)) { return false; } @@ -1156,10 +1255,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return false; } } - if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK()) - { - return false; - } } if constexpr(NumGroupsToMerge > 1) @@ -1173,7 +1268,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return false; } if constexpr(!(is_NSpatialGC_GKSpatial_NSpatialGK() || - is_NGCSpatial_GKSpatial_NGKSpatial())) + is_NGCSpatial_GKSpatial_NGKSpatial() || + is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW())) { return false; } @@ -1194,7 +1291,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // If not possible, check access per G if(!(ABlockTransferSrcVectorDim == 1 && (C == 1 || NumGroupsToMerge == 1) && (is_NSpatialGC_GKSpatial_NSpatialGK() || - is_NGCSpatial_GKSpatial_NGKSpatial()) && + is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) && G % ABlockTransferSrcScalarPerVector == 0)) { return false; @@ -1212,7 +1310,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v) + is_same_v || is_same_v || + is_same_v || is_same_v) { if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) @@ -1270,8 +1369,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } }); - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 5e9ecfd225..e91496f6a5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -325,9 +325,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_GKYXC_NGKHW(), + is_NGCHW_GKCYX_NGKHW(), ctc::NHWGC, - std::conditional_t(), + std::conditional_t(), ctc::NDHWGC, ALay>>; @@ -353,8 +353,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 static auto MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { + namespace ctc = tensor_layout::convolution; + using Layout = std::conditional_t< + is_NGCHW_GKCYX_NGKHW(), + ctc::GKYXC, + std::conditional_t(), + ctc::GKZYXC, + BLay>>; + const auto wei_gemmnraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); @@ -377,9 +385,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_GKYXC_NGKHW(), + is_NGCHW_GKCYX_NGKHW(), ctc::NHWGK, - std::conditional_t(), + std::conditional_t(), ctc::NDHWGK, ELay>>; @@ -426,6 +434,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 remove_cvref_t({}, {}))>; + using GKCYXTransposeDescType = + remove_cvref_t({}, {}))>; + using GKYXCTransposeDescType = + remove_cvref_t({}, {}))>; + static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock; using GridwiseElementwiseInputTranspose = @@ -446,6 +461,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 I1, I0>; + using GridwiseElementwiseWeightTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + NPerBlock, + NPerBlock / ClusterLengthNPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence<1>, + Sequence, + I0, + I1>; + using GridwiseElementwiseOutputTranspose = GridwiseElementwise, Tuple, @@ -508,12 +541,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 p_b_grid_{}, p_e_grid_{static_cast(p_e)}, a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, - a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides( + a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( a_g_n_c_wis_lengths, a_g_n_c_wis_strides)}, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, - b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, - e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides( + e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( e_g_n_k_wos_lengths, e_g_n_k_wos_strides)}, conv_filter_strides_{conv_filter_strides}, conv_filter_dilations_{conv_filter_dilations}, @@ -559,8 +593,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 e_grid_desc_mblock_mperblock_nblock_nperblock_ = MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) { // Use not modified base strides a_in_transpose_desc_ = @@ -570,9 +604,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( a_g_n_c_wis_lengths, a_g_n_c_wis_strides); + b_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + b_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + e_in_transpose_desc_ = conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( e_g_n_k_wos_lengths, e_g_n_k_wos_strides); + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; e_out_transpose_desc_ = conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( e_g_n_k_wos_lengths, e_g_n_k_wos_strides); @@ -586,25 +629,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 std::size_t GetWorkspaceATensorSizeBytes() const { - const long_index_t a_acum = ck::accumulate_n( - a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); - return sizeof(ADataType) * a_acum; - } - - std::size_t GetWorkspaceETensorSizeBytes() const - { - const long_index_t e_accum = ck::accumulate_n( - e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); - return sizeof(EDataType) * e_accum; - } - - std::size_t GetWorkspaceSizeBytes() const - { - // Transpose require workspace for A and B - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { - return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes(); + const long_index_t a_acum = ck::accumulate_n( + a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128; } else { @@ -612,6 +643,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } } + std::size_t GetWorkspaceBTensorSizeBytes() const + { + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const long_index_t b_acum = ck::accumulate_n( + b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + const long_index_t e_accum = ck::accumulate_n( + e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + return sizeof(EDataType) * e_accum; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + + GetWorkspaceETensorSizeBytes(); + } + void Print() const { std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; @@ -661,10 +729,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // block-to-e-tile map Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_, - elementwise_block_2_ctile_map_transpose_e_; + elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_; NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_; NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_; + GKCYXTransposeDescType b_in_transpose_desc_; + GKYXCTransposeDescType b_out_transpose_desc_; }; // Invoker @@ -702,18 +772,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; EDataType* p_e_grid = arg.p_e_grid_; - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) { p_a_grid = type_convert(arg.p_workspace_); - p_e_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType); + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); } typename GridwiseGemm::Argument gemm_arg{ - p_a_grid, arg.p_b_grid_, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1}; + p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1}; const auto Run = [&](const auto& kernel) { if(stream_config.flush_cache) @@ -1012,50 +1087,68 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { float avg_time = 0.f; - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) { - const index_t grid_size = + const index_t a_grid_size = arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( arg.a_in_transpose_desc_); + const index_t b_grid_size = + arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_); ADataType* p_a_out_grid = type_convert(arg.p_workspace_); + BDataType* p_b_out_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); - auto kernel_transpose = kernel_elementwise, - ck::Tuple, - ck::Tuple, - ck::Tuple, - Block2TileMapElementwise, - element_wise::PassThrough>; + auto kernel_transpose = kernel_elementwise_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + Block2TileMapElementwise, + element_wise::PassThrough>; avg_time += launch_and_time_kernel(stream_config, kernel_transpose, - dim3(grid_size), + dim3(a_grid_size + b_grid_size), dim3(ElementwiseBlocksize), 0, make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), make_tuple(arg.p_a_grid_), + make_tuple(arg.p_b_grid_), make_tuple(p_a_out_grid), + make_tuple(p_b_out_grid), arg.elementwise_block_2_ctile_map_transpose_a_, - element_wise::PassThrough{}); + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + a_grid_size); } avg_time += RunGemm(arg, stream_config); - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) { const index_t grid_size = arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( arg.e_in_transpose_desc_); - const EDataType* p_e_out_grid = + const EDataType* p_e_in_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType); + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); - EDataType* p_e_in_grid = arg.p_e_grid_; + EDataType* p_e_out_grid = arg.p_e_grid_; auto kernel_transpose = kernel_elementwise, @@ -1072,8 +1165,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 0, make_tuple(arg.e_in_transpose_desc_), make_tuple(arg.e_out_transpose_desc_), - make_tuple(p_e_out_grid), make_tuple(p_e_in_grid), + make_tuple(p_e_out_grid), arg.elementwise_block_2_ctile_map_transpose_e_, element_wise::PassThrough{}); } @@ -1118,12 +1211,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // check if it's 1x1, stride=1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t ConvStride = arg.conv_filter_strides_[i]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; - if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) { return false; } @@ -1135,11 +1228,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // check if it's 1x1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; - const index_t LeftPad = arg.input_left_pads_[i]; - const index_t RightPad = arg.input_right_pads_[i]; + const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; - if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0)) { return false; } @@ -1171,7 +1264,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v) + is_same_v || is_same_v || + is_same_v || is_same_v) { if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) @@ -1184,8 +1278,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 return false; } - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) { if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp index 3bcd8859aa..5de429f9e5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -59,6 +59,22 @@ constexpr bool is_NGCHW_GKYXC_NGKHW() is_same_v && is_same_v; } + +template +constexpr bool is_NGCHW_GKCYX_NGKHW() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_NGCHW_NGKHW() +{ + return is_same_v && + is_same_v; +} + // 3d template constexpr bool is_NDHWGC_GKZYXC_NDHWGK() @@ -84,6 +100,21 @@ constexpr bool is_NGCDHW_GKZYXC_NGKDHW() is_same_v; } +template +constexpr bool is_NGCDHW_GKCZYX_NGKDHW() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_NGCDHW_NGKDHW() +{ + return is_same_v && + is_same_v; +} + template constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK() { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp index 856ba22146..0edfc9b0ee 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp @@ -41,13 +41,16 @@ __global__ void elementwise_op); } -template @@ -55,14 +58,14 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_elementwise_dual(const InBGridDescTuple in_grid_desc_tuple_a, + kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, - const InDataTypePointerTuple p_in_global_tuple_a, - const InDataTypePointerTuple p_in_global_tuple_b, - const OutDataTypePointerTuple p_out_global_tuple_a, - const OutDataTypePointerTuple p_out_global_tuple_b, + const InADataTypePointerTuple p_in_global_tuple_a, + const InBDataTypePointerTuple p_in_global_tuple_b, + const OutADataTypePointerTuple p_out_global_tuple_a, + const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, @@ -70,23 +73,23 @@ __global__ void { if(get_block_1d_id() < a_grid_size) { - GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_a, - out_grid_desc_tuple_a, - p_in_global_tuple_a, - p_out_global_tuple_a, - block_2_tile_map_a, - elementwise_op, - get_block_1d_id()); + GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a, + out_grid_desc_tuple_a, + p_in_global_tuple_a, + p_out_global_tuple_a, + block_2_tile_map_a, + elementwise_op, + get_block_1d_id()); } else { - GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_b, - out_grid_desc_tuple_b, - p_in_global_tuple_b, - p_out_global_tuple_b, - block_2_tile_map_b, - elementwise_op, - get_block_1d_id() - a_grid_size); + GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b, + out_grid_desc_tuple_b, + p_in_global_tuple_b, + p_out_global_tuple_b, + block_2_tile_map_b, + elementwise_op, + get_block_1d_id() - a_grid_size); } } diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp index 2bf1c40a12..7bf52cb229 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp @@ -28,9 +28,10 @@ struct TransformConvNGCHWToNHWGC static constexpr auto I5 = Number<5>{}; template ::type = false> - static auto MakeNGCHWTransposeDesc(std::array g_n_c_wis_lengths, - std::array g_n_c_wis_strides, - const index_t split_n_size = 1) + static auto + MakeNGCHWTransposeDesc(const std::array& g_n_c_wis_lengths, + const std::array& g_n_c_wis_strides, + const index_t split_n_size = 1) { const index_t& G = g_n_c_wis_lengths[I0]; const index_t N = g_n_c_wis_lengths[I1] / split_n_size; @@ -55,9 +56,10 @@ struct TransformConvNGCHWToNHWGC } template ::type = false> - static auto MakeNHWGCTransposeDesc(std::array g_n_c_wis_lengths, - std::array g_n_c_wis_strides, - const index_t split_n_size = 1) + static auto + MakeNHWGCTransposeDesc(const std::array& g_n_c_wis_lengths, + const std::array& g_n_c_wis_strides, + const index_t split_n_size = 1) { const index_t& G = g_n_c_wis_lengths[I0]; const index_t N = g_n_c_wis_lengths[I1] / split_n_size; @@ -82,9 +84,10 @@ struct TransformConvNGCHWToNHWGC } template ::type = false> - static auto MakeNGCHWTransposeDesc(std::array g_n_c_wis_lengths, - std::array g_n_c_wis_strides, - const index_t split_n_size = 1) + static auto + MakeNGCHWTransposeDesc(const std::array& g_n_c_wis_lengths, + const std::array& g_n_c_wis_strides, + const index_t split_n_size = 1) { const index_t& G = g_n_c_wis_lengths[I0]; const index_t N = g_n_c_wis_lengths[I1] / split_n_size; @@ -111,9 +114,10 @@ struct TransformConvNGCHWToNHWGC } template ::type = false> - static auto MakeNHWGCTransposeDesc(std::array g_n_c_wis_lengths, - std::array g_n_c_wis_strides, - const index_t split_n_size = 1) + static auto + MakeNHWGCTransposeDesc(const std::array& g_n_c_wis_lengths, + const std::array& g_n_c_wis_strides, + const index_t split_n_size = 1) { const index_t& G = g_n_c_wis_lengths[I0]; const index_t N = g_n_c_wis_lengths[I1] / split_n_size; @@ -140,9 +144,10 @@ struct TransformConvNGCHWToNHWGC } template ::type = false> - static auto MakeNGCHWTransposeDesc(std::array g_n_c_wis_lengths, - std::array g_n_c_wis_strides, - const index_t split_n_size = 1) + static auto + MakeNGCHWTransposeDesc(const std::array& g_n_c_wis_lengths, + const std::array& g_n_c_wis_strides, + const index_t split_n_size = 1) { const index_t& G = g_n_c_wis_lengths[I0]; const index_t N = g_n_c_wis_lengths[I1] / split_n_size; @@ -172,9 +177,10 @@ struct TransformConvNGCHWToNHWGC } template ::type = false> - static auto MakeNHWGCTransposeDesc(std::array g_n_c_wis_lengths, - std::array g_n_c_wis_strides, - const index_t split_n_size = 1) + static auto + MakeNHWGCTransposeDesc(const std::array& g_n_c_wis_lengths, + const std::array& g_n_c_wis_strides, + const index_t split_n_size = 1) { const index_t& G = g_n_c_wis_lengths[I0]; const index_t N = g_n_c_wis_lengths[I1] / split_n_size; @@ -203,11 +209,185 @@ struct TransformConvNGCHWToNHWGC merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); } - static auto TransposeStrides(const std::array& g_n_c_wis_lengths, - const std::array& g_n_c_wis_strides) + template ::type = false> + static auto + MakeGKCYXTransposeDesc(const std::array& g_k_c_wis_lengths, + const std::array& g_k_c_wis_strides) { - if constexpr(device::is_NGCHW_GKYXC_NGKHW() || - device::is_NGCDHW_GKZYXC_NGKDHW()) + const index_t& G = g_k_c_wis_lengths[I0]; + const index_t& K = g_k_c_wis_lengths[I1]; + const index_t& C = g_k_c_wis_lengths[I2]; + const index_t& X = g_k_c_wis_lengths[I3]; + + const index_t& GStride = g_k_c_wis_strides[I0]; + const index_t& KStride = g_k_c_wis_strides[I1]; + const index_t& CStride = g_k_c_wis_strides[I2]; + const index_t& XStride = g_k_c_wis_strides[I3]; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(G, K, C, X), make_tuple(GStride, KStride, CStride, XStride)); + const auto merged_desc = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(make_tuple(G, K, X)), make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto + MakeGKYXCTransposeDesc(const std::array& g_k_c_wis_lengths, + const std::array& g_k_c_wis_strides) + { + const index_t& G = g_k_c_wis_lengths[I0]; + const index_t& K = g_k_c_wis_lengths[I1]; + const index_t& C = g_k_c_wis_lengths[I2]; + const index_t& X = g_k_c_wis_lengths[I3]; + + const index_t& GStride = g_k_c_wis_strides[I0]; + const index_t KStride = g_k_c_wis_strides[I1]; + const index_t CStride = 1; + const index_t XStride = C; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(G, K, C, X), make_tuple(GStride, KStride, CStride, XStride)); + const auto merged_desc = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(make_tuple(G, K, X)), make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto + MakeGKCYXTransposeDesc(const std::array& g_k_c_wis_lengths, + const std::array& g_k_c_wis_strides) + { + const index_t& G = g_k_c_wis_lengths[I0]; + const index_t& K = g_k_c_wis_lengths[I1]; + const index_t& C = g_k_c_wis_lengths[I2]; + const index_t& Y = g_k_c_wis_lengths[I3]; + const index_t& X = g_k_c_wis_lengths[I4]; + + const index_t& GStride = g_k_c_wis_strides[I0]; + const index_t& KStride = g_k_c_wis_strides[I1]; + const index_t& CStride = g_k_c_wis_strides[I2]; + const index_t& YStride = g_k_c_wis_strides[I3]; + const index_t& XStride = g_k_c_wis_strides[I4]; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(G, K, C, Y, X), make_tuple(GStride, KStride, CStride, YStride, XStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(G, K, Y, X)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 3, 4>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto + MakeGKYXCTransposeDesc(const std::array& g_k_c_wis_lengths, + const std::array& g_k_c_wis_strides) + { + const index_t& G = g_k_c_wis_lengths[I0]; + const index_t& K = g_k_c_wis_lengths[I1]; + const index_t& C = g_k_c_wis_lengths[I2]; + const index_t& Y = g_k_c_wis_lengths[I3]; + const index_t& X = g_k_c_wis_lengths[I4]; + + const index_t& GStride = g_k_c_wis_strides[I0]; + const index_t KStride = g_k_c_wis_strides[I1]; + const index_t CStride = 1; + const index_t YStride = X * C; + const index_t XStride = C; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(G, K, C, Y, X), make_tuple(GStride, KStride, CStride, YStride, XStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(G, K, Y, X)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 3, 4>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto + MakeGKCYXTransposeDesc(const std::array& g_k_c_wis_lengths, + const std::array& g_k_c_wis_strides) + { + const index_t& G = g_k_c_wis_lengths[I0]; + const index_t& K = g_k_c_wis_lengths[I1]; + const index_t& C = g_k_c_wis_lengths[I2]; + const index_t& Z = g_k_c_wis_lengths[I3]; + const index_t& Y = g_k_c_wis_lengths[I4]; + const index_t& X = g_k_c_wis_lengths[I5]; + + const index_t& GStride = g_k_c_wis_strides[I0]; + const index_t& KStride = g_k_c_wis_strides[I1]; + const index_t& CStride = g_k_c_wis_strides[I2]; + const index_t& ZStride = g_k_c_wis_strides[I3]; + const index_t& YStride = g_k_c_wis_strides[I4]; + const index_t& XStride = g_k_c_wis_strides[I5]; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(G, K, C, Z, Y, X), + make_tuple(GStride, KStride, CStride, ZStride, YStride, XStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(G, K, Z, Y, X)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 3, 4, 5>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto + MakeGKYXCTransposeDesc(const std::array& g_k_c_wis_lengths, + const std::array& g_k_c_wis_strides) + { + const index_t& G = g_k_c_wis_lengths[I0]; + const index_t& K = g_k_c_wis_lengths[I1]; + const index_t& C = g_k_c_wis_lengths[I2]; + const index_t& Z = g_k_c_wis_lengths[I3]; + const index_t& Y = g_k_c_wis_lengths[I4]; + const index_t& X = g_k_c_wis_lengths[I5]; + + const index_t& GStride = g_k_c_wis_strides[I0]; + const index_t KStride = g_k_c_wis_strides[I1]; + const index_t CStride = 1; + const index_t ZStride = Y * X * C; + const index_t YStride = X * C; + const index_t XStride = C; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(G, K, C, Z, Y, X), + make_tuple(GStride, KStride, CStride, ZStride, YStride, XStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(G, K, Z, Y, X)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 3, 4, 5>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + static auto TransposeInOutStrides(const std::array& g_n_c_wis_lengths, + const std::array& g_n_c_wis_strides) + { + if constexpr(device::is_NGCHW_NGKHW() || + device::is_NGCDHW_NGKDHW()) { std::array g_n_c_wis_strides_transposed; const auto G = g_n_c_wis_lengths[I0]; @@ -236,6 +416,41 @@ struct TransformConvNGCHWToNHWGC return g_n_c_wis_strides; } } + + static auto + TransposeWeiStrides(const std::array& g_k_c_wis_lengths, + const std::array& g_k_c_wis_strides) + { + if constexpr(device::is_NGCHW_GKCYX_NGKHW() || + device::is_NGCDHW_GKCZYX_NGKDHW()) + { + std::array g_k_c_wis_strides_transposed = g_k_c_wis_strides; + const index_t C = g_k_c_wis_lengths[I2]; + + if constexpr(NDimSpatial == 2) + { + const index_t X = g_k_c_wis_lengths[I4]; + g_k_c_wis_strides_transposed[I2] = 1; + g_k_c_wis_strides_transposed[I3] = X * C; + g_k_c_wis_strides_transposed[I4] = C; + } + else if constexpr(NDimSpatial == 3) + { + const index_t Y = g_k_c_wis_lengths[I4]; + const index_t X = g_k_c_wis_lengths[I5]; + g_k_c_wis_strides_transposed[I2] = 1; + g_k_c_wis_strides_transposed[I3] = Y * X * C; + g_k_c_wis_strides_transposed[I4] = X * C; + g_k_c_wis_strides_transposed[I5] = C; + } + return g_k_c_wis_strides_transposed; + } + else + { + // transpose not needed + return g_k_c_wis_strides; + } + } }; } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 4a44c425aa..c3fd04ba35 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -71,6 +71,10 @@ using GKXC = ck::tensor_layout::convolution::GKXC; using GKYXC = ck::tensor_layout::convolution::GKYXC; using GKZYXC = ck::tensor_layout::convolution::GKZYXC; +using GKCX = ck::tensor_layout::convolution::GKCX; +using GKCYX = ck::tensor_layout::convolution::GKCYX; +using GKCZYX = ck::tensor_layout::convolution::GKCZYX; + using GNWK = ck::tensor_layout::convolution::GNWK; using GNHWK = ck::tensor_layout::convolution::GNHWK; using GNDHWK = ck::tensor_layout::convolution::GNDHWK; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 01415c2ddd..c2e1337737 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -272,20 +272,20 @@ struct DeviceOperationInstanceFactory && - is_same_v && is_same_v) + is_same_v && is_same_v) { #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances( + add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances( op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances( + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instances( op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances( + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instances( op_ptrs); } #endif @@ -294,13 +294,13 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances( + add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instances( op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances( + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_intra_instances( op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances( + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_inter_instances( op_ptrs); } #endif @@ -311,14 +311,46 @@ struct DeviceOperationInstanceFactory && is_same_v) { - add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instances( + add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instances( op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instances( + op_ptrs); + } +#endif + } + + // layout NGCHW/GKYXC/NGKHW + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_INT8 @@ -326,14 +358,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instances( - op_ptrs); add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instances( - op_ptrs); } #endif } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc index 9a83e36b99..1f924737cd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc @@ -73,12 +73,12 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances( PassThrough>>>& instances); #endif -// grouped conv2d forward, NGCHW/GKYXC/NGKHW +// grouped conv2d forward, NGCHW/GKCYX/NGKHW #ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances( +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instances( std::vector>>& instances); #endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances( - std::vector>>& instances); -#endif - #ifdef CK_ENABLE_BF16 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc index 662fadadcf..3900c7a0fb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc @@ -73,12 +73,12 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance PassThrough>>>& instances); #endif -// grouped conv2d forward, NGCHW/GKYXC/NGKHW +// grouped conv2d forward, NGCHW/GKCYX/NGKHW #ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances( +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_inter_instances( std::vector>>& instances); #endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instances( - std::vector>>& instances); -#endif - #ifdef CK_ENABLE_BF16 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc index f283fe8550..b7815f5023 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc @@ -73,12 +73,12 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance PassThrough>>>& instances); #endif -// grouped conv2d forward, NGCHW/GKYXC/NGKHW +// grouped conv2d forward, NGCHW/GKCYX/NGKHW #ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances( +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_intra_instances( std::vector>>& instances); #endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instances( - std::vector>>& instances); -#endif - #ifdef CK_ENABLE_BF16 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index c977c89c94..b934b9aef6 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -252,6 +252,55 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instances( PassThrough>>>& instances); #endif +// grouped conv2d forward, NGCHW/GKCYX/NGKHW +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances( + std::vector>>& instances); +#endif + #ifdef CK_ENABLE_BF16 // grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc index a81e1e07ba..966b883301 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc @@ -24,10 +24,10 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_inst PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instances( +void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances( +void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances( +void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instances( +void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_int8_instances( std::vector{}); @@ -39,7 +39,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances( instances, device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, NGCHW, - GKYXC, + GKCYX, Empty_Tuple, NGKHW, ConvFwdDefault>{}); @@ -51,7 +51,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances( instances, device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, NGCHW, - GKYXC, + GKCYX, Empty_Tuple, NGKHW, ConvFwdDefault>{}); diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp similarity index 92% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp index ea6ba831f1..13e0e91f97 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" @@ -10,10 +10,10 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances( +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instances( std::vector{}); @@ -39,7 +39,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances( instances, device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, NGCHW, - GKYXC, + GKCYX, Empty_Tuple, NGKHW, ConvFwdDefault>{}); @@ -51,7 +51,7 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances( instances, device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, NGCHW, - GKYXC, + GKCYX, Empty_Tuple, NGKHW, ConvFwdDefault>{}); diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instance.cpp similarity index 89% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instance.cpp index ba3e982e99..3a93c16138 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" @@ -9,10 +9,10 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances( +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instances( std::vector{}); diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp deleted file mode 100644 index 8f0a5ca425..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp +++ /dev/null @@ -1,64 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/host_utility/device_prop.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_int8_comp_instances<2, - NGCHW, - GKYXC, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); - - if(ck::get_device_name() != "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_int8_comp_instances_part2<2, - NGCHW, - GKYXC, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); - } - - if(ck::get_device_name() == "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_int8_comp_instances_2x<2, - NGCHW, - GKYXC, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); - } -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp new file mode 100644 index 0000000000..6c5d9b5b94 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp new file mode 100644 index 0000000000..f1ccad2add --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp new file mode 100644 index 0000000000..de7e416e48 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instance.cpp similarity index 92% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instance.cpp index 88b5f30da5..d57c67ba07 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -9,10 +9,10 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instances( +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instances( std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NGCHW, - GKYXC, - Empty_Tuple, - NGKHW, - ConvFwdDefault, - Interwave>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instance.cpp deleted file mode 100644 index 217f57d879..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instance.cpp +++ /dev/null @@ -1,39 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NGCHW, - GKYXC, - Empty_Tuple, - NGKHW, - ConvFwdDefault, - Intrawave>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instance.cpp similarity index 93% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instance.cpp index 14f00d8e88..a8ebcaa6b4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" @@ -9,10 +9,10 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instances( +void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instances( std::vector{}); @@ -36,7 +36,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_inst instances, device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, NGCHW, - GKYXC, + GKCYX, Empty_Tuple, NGKHW, ConvFwd3x3>{}); diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instance.cpp similarity index 93% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instance.cpp index 3ae1ba3d05..5571e11aa0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" @@ -9,10 +9,10 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances( +void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instances( std::vector{}); @@ -36,7 +36,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_insta instances, device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2, NGCHW, - GKYXC, + GKCYX, Empty_Tuple, NGKHW, ConvFwd3x3>{}); diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instance.cpp similarity index 93% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instance.cpp index cc570568f3..252b09a1c4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" @@ -9,10 +9,10 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances( +void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances( std::vector{}); @@ -36,7 +36,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_insta instances, device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2, NGCHW, - GKYXC, + GKCYX, Empty_Tuple, NGKHW, ConvFwd3x3>{}); diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instance.cpp deleted file mode 100644 index c66d48ed7a..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instance.cpp +++ /dev/null @@ -1,48 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_merged_groups_int8_instances<2, - NGCHW, - GKYXC, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_merged_groups_int8_instances<2, - NGCHW, - GKYXC, - Empty_Tuple, - NGKHW, - ConvFwd3x3>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index 7faf573dbf..9ee05d1304 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -16,6 +16,7 @@ enum struct ConvLayout GNHWC_GKYXC_GNHWK, // 0 NHWGC_GKYXC_NHWGK, // 1 NGCHW_GKYXC_NGKHW, // 2 + NGCHW_GKCYX_NGKHW, // 3 }; enum struct ConvDataType @@ -52,11 +53,13 @@ static void print_helper_msg() << " 5: Input bf8, Weight bf8, Output fp8\n" << " 6: Input fp8, Weight bf8, Output fp8\n" << " 7: Input bf8, Weight fp8, Output fp8)\n" - << "arg3: indexing data type (0: 32-bit, 1: 64-bit)\n" - << "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" - << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" + << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " - "G, K, Ho, Wo]\n" + "G, K, Ho, Wo]\n" + << " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, " + "G, K, Ho, Wo])\n" + << "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n" << "arg5: verification (0: no, 1: yes)\n" << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" << "arg7: print tensor value (0: no; 1: yes)\n" @@ -110,6 +113,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using GKYXC = ck::tensor_layout::convolution::GKYXC; using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + // using GKCX = ck::tensor_layout::convolution::GKXC; + using GKCYX = ck::tensor_layout::convolution::GKCYX; + // using GKCZYX = ck::tensor_layout::convolution::GKZYXC; + using GNWK = ck::tensor_layout::convolution::GNWK; using GNHWK = ck::tensor_layout::convolution::GNHWK; using GNDHWK = ck::tensor_layout::convolution::GNDHWK; @@ -302,6 +309,25 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{}); } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } + else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { diff --git a/script/convert_miopen_driver_to_profiler.py b/script/convert_miopen_driver_to_profiler.py index a9dec2ec95..9bb668e164 100644 --- a/script/convert_miopen_driver_to_profiler.py +++ b/script/convert_miopen_driver_to_profiler.py @@ -19,17 +19,20 @@ def init_const_args(args): def run_ck_profiler_cmd(cmd): print("ckProfiler command:") - print(cmd) + cmd_concatenated_str = "" + for arg in cmd: + cmd_concatenated_str += arg + " " + print(cmd_concatenated_str) subprocess.run(cmd) def parse_layouts(args): if args.in_layout == "NCW" or args.in_layout == "NCHW" or \ args.in_layout == "NCDHW": - if args.ck_profier_op == "grouped_conv_bwd_weight": - args.layout = 3 - elif args.ck_profier_op == "grouped_conv_bwd_data" or \ + if args.ck_profier_op == "grouped_conv_bwd_weight" or \ args.ck_profier_op == "grouped_conv_fwd": + args.layout = 3 + elif args.ck_profier_op == "grouped_conv_bwd_data": args.layout = 2 else: print('Not supported layout for this op') diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp index 25481e0d7f..43b77641d1 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -65,7 +65,10 @@ using KernelTypes2d = ::testing::Types, std::tuple, std::tuple, std::tuple, - std::tuple>; + std::tuple, + std::tuple, + std::tuple, + std::tuple>; using KernelTypes3d = ::testing::Types, std::tuple,