From 69a6b563f9836ff40e1278ac1199b64d667aa291 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 6 Aug 2024 10:06:10 +0200 Subject: [PATCH] Add Grouped Conv Fwd Large Tensor kernel (#1432) * Support 64 bit indexing * Add new grouped conv fwd kernel for large tensors * Add instances large tensor * Fixes for transform conv to gemm * Fixes * fixes * Remove not needed instances * examples fixes * Remove not need ds arrays * Fix tests * Add 2GB check in gridwise dl * Fixes [ROCm/composable_kernel commit: 4ec5c52a0c01e9b34ae9c3918a0c9372075e5852] --- .../common.hpp | 4 +- .../convnd_bwd_data_common.hpp | 45 +- .../device_grouped_conv_fwd_multiple_abd.hpp | 23 + ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 87 +- .../impl/device_column_to_image_impl.hpp | 4 +- ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp | 162 ++- ...ice_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp | 12 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 171 ++- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 164 ++- ...fwd_multiple_d_multiple_r_xdl_cshuffle.hpp | 14 +- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 166 ++- ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 1054 +++++++++++++++++ .../impl/device_image_to_column_impl.hpp | 4 +- .../gpu/grid/gridwise_gemm_dl_multiple_d.hpp | 11 +- .../gpu/grid/gridwise_gemm_dl_v1r3.hpp | 11 +- .../transform_conv_fwd_to_gemm.hpp | 506 ++++++-- include/ck/utility/type_convert.hpp | 22 +- .../cpu/reference_column_to_image.hpp | 116 +- .../cpu/reference_conv_bwd_data.hpp | 24 +- .../cpu/reference_conv_bwd_weight.hpp | 24 +- .../cpu/reference_conv_fwd.hpp | 26 +- .../cpu/reference_image_to_column.hpp | 99 +- ...ped_conv_fwd_xdl_large_tensor_instance.hpp | 93 ++ .../gpu/grouped_convolution_forward.hpp | 13 + ...d_convolution_forward_xdl_large_tensor.inc | 112 ++ .../library/utility/convolution_parameter.hpp | 40 +- .../ck/library/utility/host_tensor.hpp | 21 +- .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 5 + ...tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 39 + ..._tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 39 + ..._tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp | 39 + .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 4 + ...sor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 39 + ...nsor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 39 + ...nsor_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 39 + library/src/utility/convolution_parameter.cpp | 96 +- .../profiler/profile_conv_bwd_data_impl.hpp | 45 +- .../profiler/profile_conv_fwd_impl.hpp | 45 +- .../profile_grouped_conv_fwd_impl.hpp | 23 +- profiler/src/profile_grouped_conv_fwd.cpp | 83 +- test/conv_util/conv_util.cpp | 56 +- .../test_grouped_convnd_fwd.cpp | 52 +- 42 files changed, 3220 insertions(+), 451 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp index 137b0d1ff0..7e3130a1a1 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -139,7 +139,7 @@ inline bool parse_cmd_args(int argc, inline HostTensorDescriptor make_r0_host_tensor_descriptor(const ck::utils::conv::ConvParam& problem_size) { - std::vector dimensions{problem_size.G_, problem_size.N_}; + std::vector dimensions{problem_size.G_, problem_size.N_}; ck::ranges::copy(problem_size.output_spatial_lengths_, std::back_inserter(dimensions)); diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp index 4a9d16c5c3..d219df0245 100644 --- a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp +++ b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -80,6 +80,29 @@ int run_conv_bwd_data(bool do_verification, // reset input to zero in_device_buf.SetZero(); + std::vector input_spatial_lengths_i32(NDimSpatial); + std::vector filter_spatial_lengths_i32(NDimSpatial); + std::vector output_spatial_lengths_i32(NDimSpatial); + std::vector conv_filter_strides_i32(NDimSpatial); + std::vector conv_filter_dilations_i32(NDimSpatial); + std::vector input_left_pads_i32(NDimSpatial); + std::vector input_right_pads_i32(NDimSpatial); + + for(ck::index_t d = 0; d < NDimSpatial; d++) + { + input_spatial_lengths_i32[d] = + static_cast(conv_param.input_spatial_lengths_[d]); + filter_spatial_lengths_i32[d] = + static_cast(conv_param.filter_spatial_lengths_[d]); + output_spatial_lengths_i32[d] = + static_cast(conv_param.GetOutputSpatialLengths()[d]); + conv_filter_strides_i32[d] = static_cast(conv_param.conv_filter_strides_[d]); + conv_filter_dilations_i32[d] = + static_cast(conv_param.conv_filter_dilations_[d]); + input_left_pads_i32[d] = static_cast(conv_param.input_left_pads_[d]); + input_right_pads_i32[d] = static_cast(conv_param.input_right_pads_[d]); + } + // do GEMM auto conv = DeviceConvNdBwdDataInstance{}; auto invoker = conv.MakeInvoker(); @@ -87,16 +110,16 @@ int run_conv_bwd_data(bool do_verification, conv.MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), - conv_param.N_, - conv_param.K_, - conv_param.C_, - conv_param.input_spatial_lengths_, - conv_param.filter_spatial_lengths_, - conv_param.GetOutputSpatialLengths(), - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, + static_cast(conv_param.N_), + static_cast(conv_param.K_), + static_cast(conv_param.C_), + input_spatial_lengths_i32, + filter_spatial_lengths_i32, + output_spatial_lengths_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, in_element_op, wei_element_op, out_element_op); diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp index 31e8d639ad..184efbbd68 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp @@ -126,6 +126,29 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator const BElementwiseOperation& b_element_op, const CDEElementwiseOperation& cde_element_op) = 0; + virtual std::unique_ptr + MakeArgumentPointer(APointers p_a, + BPointers p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) = 0; + virtual std::unique_ptr MakeInvokerPointer() = 0; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 33398a1c0b..180e32c8b6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -359,14 +359,14 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - using GemmToConvFwdTransformer = TransformConvFwdToGemm; + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; template __host__ __device__ static auto - MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(); @@ -379,7 +379,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle template __host__ __device__ static auto - MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(); @@ -392,7 +392,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle template __host__ __device__ static auto - MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); @@ -405,7 +405,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // Shape of Ds and E must be aligned. Strides can be different. // Pass e_g_n_k_wos_lengths for logical broadcast. - static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { return generate_tuple( [&](auto i) { @@ -417,7 +417,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } // desc for problem definition - constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; using AGridDesc_M_K = remove_cvref_t(dummy_conv_to_gemm_transformer))>; using BGridDesc_N_K = @@ -617,7 +617,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // D batch stride compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; - GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, a_g_n_c_wis_strides, b_g_k_c_xs_lengths, b_g_k_c_xs_strides, @@ -686,7 +686,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // tensor descriptors for problem definiton index_t num_group_; - GemmToConvFwdTransformer conv_to_gemm_transformer_; + ConvToGemmFwdTransformer conv_to_gemm_transformer_; AGridDesc_M_K a_grid_desc_m_k_; BGridDesc_N_K b_grid_desc_n_k_; @@ -943,6 +943,77 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle b_element_op, cde_element_op}; } + + static __device__ __host__ auto MakeArgument( + APointers p_as, + BPointers p_bs, + const ck::Array& p_ds, + void* p_e, + const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array, NumDTensor>& ds_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op}; + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp index e7ac9c0314..e4203e0313 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp @@ -64,7 +64,7 @@ struct DeviceColumnToImageImpl static constexpr auto spatial_offset = Number<3>{}; - using GemmToConvFwdTransformer = + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{ @@ -233,7 +233,7 @@ struct DeviceColumnToImageImpl : independent_filter_stride; } - GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths, + ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths, image_g_n_c_wis_strides, b_g_k_c_xs_lengths, {}, // not needed for A Descriptor diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 6ee086d09e..65b7b6cb7a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -238,14 +238,14 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - using GemmToConvFwdTransformer = TransformConvFwdToGemm; + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, K0PerBlock}; template static auto - MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(); @@ -266,7 +266,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK template static auto - MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(); @@ -287,7 +287,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK } template - static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); @@ -298,7 +298,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK return out_gemmm_gemmn_desc; } - static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { return generate_tuple( [&](auto i) { @@ -310,7 +310,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK } // desc for problem definition - constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; using AGridDesc_AK0_M_AK1 = remove_cvref_t( dummy_conv_to_gemm_transformer))>; using BGridDesc_BK0_N_BK1 = remove_cvref_t( @@ -447,7 +447,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK using DLayout = remove_cvref_t>; using DDataType = remove_cvref_t>; - GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, a_g_n_c_wis_strides, b_g_k_c_xs_lengths, b_g_k_c_xs_strides, @@ -511,7 +511,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK // tensor descriptors for problem definiton index_t num_group_; - GemmToConvFwdTransformer conv_to_gemm_transformer_; + ConvToGemmFwdTransformer conv_to_gemm_transformer_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; @@ -836,6 +836,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK cde_element_op}; } + static auto + MakeArgument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op}; + } + static auto MakeInvoker() { return Invoker{}; } std::unique_ptr MakeArgumentPointer( @@ -880,6 +953,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK cde_element_op); } + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op); + } + std::unique_ptr MakeInvokerPointer() override { return std::make_unique(Invoker{}); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index 0a58cd0c88..50e171e503 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -234,14 +234,14 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd{}; static constexpr auto I3 = Number<3>{}; - using GemmToConvFwdTransformer = TransformConvFwdToGemm; + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, K0PerBlock}; template static auto - MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(); @@ -263,7 +263,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd static auto - MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(); @@ -284,7 +284,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd - static auto MakeCGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeCGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); @@ -296,7 +296,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd( dummy_conv_to_gemm_transformer))>; using BGridDesc_BK0_N_BK1 = remove_cvref_t( @@ -452,7 +452,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd{}; static constexpr auto I3 = Number<3>{}; - using GemmToConvFwdTransformer = TransformConvFwdToGemm{MPerBlock, NPerBlock, KPerBlock}; template - static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(); @@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } template - static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(); @@ -351,7 +351,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } template - static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); @@ -364,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // Shape of Ds and E must be aligned. Strides can be different. // Pass e_g_n_k_wos_lengths for logical broadcast. - static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { return generate_tuple( [&](auto i) { @@ -376,7 +376,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } // desc for problem definition - constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; using AGridDesc_M_K = remove_cvref_t(dummy_conv_to_gemm_transformer))>; using BGridDesc_N_K = @@ -595,7 +595,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle compute_ptr_offset_of_n_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][1] * conv_N_per_block_; - GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, a_g_n_c_wis_strides, b_g_k_c_xs_lengths, b_g_k_c_xs_strides, @@ -674,7 +674,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // tensor descriptors for problem definiton index_t num_group_; - GemmToConvFwdTransformer conv_to_gemm_transformer_; + ConvToGemmFwdTransformer conv_to_gemm_transformer_; index_t conv_N_per_block_; @@ -1129,11 +1129,84 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle cde_element_op}; } + static auto + MakeArgument(APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op}; + } + static auto MakeInvoker() { return Invoker{}; } std::unique_ptr MakeArgumentPointer( - APointers p_a, - BPointers p_b, + APointers p_as, + BPointers p_bs, const std::array& p_ds, void* p_e, const std::array& a_g_n_c_wis_lengths, @@ -1152,8 +1225,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle const BElementwiseOperation& b_element_op, const CDEElementwiseOperation& cde_element_op) override { - return std::make_unique(p_a, - p_b, + return std::make_unique(p_as, + p_bs, p_ds, p_e, a_g_n_c_wis_lengths, @@ -1173,6 +1246,80 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle cde_element_op); } + std::unique_ptr + MakeArgumentPointer(APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op); + } + std::unique_ptr MakeInvokerPointer() override { return std::make_unique(Invoker{}); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 025123a880..6f45f81603 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -293,7 +293,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - using GemmToConvFwdTransformer = TransformConvFwdToGemm static auto - MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = @@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 template static auto - MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(); @@ -348,7 +348,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } template - static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = @@ -361,7 +361,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } // desc for problem definition - constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; using EGridDesc_M_N = remove_cvref_t(dummy_conv_to_gemm_transformer))>; @@ -495,7 +495,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // tensor descriptors for problem definiton index_t num_group_; - GemmToConvFwdTransformer conv_to_gemm_transformer_; + ConvToGemmFwdTransformer conv_to_gemm_transformer_; index_t conv_N_per_block_; @@ -978,6 +978,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 return false; } + // Gridwise gemm v3 doesn't verify descriptors size + if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB()) + { + return false; + } + // check Gridwise GEMM const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); @@ -1037,6 +1043,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 cde_element_op}; } + static auto + MakeArgument(const void* p_as, + const void* p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op}; + } + static auto MakeInvoker() { return Invoker{}; } std::unique_ptr MakeArgumentPointer( @@ -1081,6 +1160,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 cde_element_op); } + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op); + } + std::unique_ptr MakeInvokerPointer() override { return std::make_unique(Invoker{}); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index 4a8cb2d592..58f1396710 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -309,13 +309,13 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - using GemmToConvFwdTransformer = TransformConvFwdToGemm; + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; template - static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(); @@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle } template - static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(); @@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle } template - static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); @@ -420,7 +420,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo); } - constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; using AGridDesc_M_K = remove_cvref_t(dummy_conv_to_gemm_transformer))>; using BGridDesc_N_K = @@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle // D batch stride compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; - GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, a_g_n_c_wis_strides, b_g_k_c_xs_lengths, b_g_k_c_xs_strides, @@ -649,7 +649,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle EDataType* p_e_grid_; typename GridwiseGemm::RsGridPointer p_rs_grid_; - GemmToConvFwdTransformer conv_to_gemm_transformer_; + ConvToGemmFwdTransformer conv_to_gemm_transformer_; // tensor descriptors for problem definiton AGridDesc_M_K a_grid_desc_m_k_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 981f9f421b..1c9f6b0094 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -135,13 +135,13 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1); - using GemmToConvFwdTransformer = TransformConvFwdToGemm; + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; template - static auto MakeAGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeAGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(); @@ -185,7 +185,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle } template - static auto MakeBGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeBGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(); @@ -229,7 +229,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle } template - static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); @@ -240,7 +240,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle return out_gemmm_gemmn_desc; } - static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer) + static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { return generate_tuple( [&](auto i) { @@ -252,7 +252,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle } // desc for problem definition - constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer; + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor(dummy_conv_to_gemm_transformer)); using BGridDesc = @@ -406,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle [&](auto i) { using DLayout = remove_cvref_t>; - GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, a_g_n_c_wis_strides, b_g_k_c_xs_lengths, b_g_k_c_xs_strides, @@ -448,7 +448,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // tensor descriptors for problem definiton index_t num_group_; - GemmToConvFwdTransformer conv_to_gemm_transformer_; + ConvToGemmFwdTransformer conv_to_gemm_transformer_; DsGridDesc_M_N ds_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_; @@ -772,6 +772,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle cde_element_op}; } + static auto + MakeArgument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op}; + } + static auto MakeInvoker() { return Invoker{}; } std::unique_ptr MakeArgumentPointer( @@ -818,6 +893,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle cde_element_op); } + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op); + } + std::unique_ptr MakeInvokerPointer() override { return std::make_unique(Invoker{}); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp new file mode 100644 index 0000000000..845751c51c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -0,0 +1,1054 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle( + Array gemm_desc_kernel_args, + const index_t gemms_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op, + const ComputePtrOffset compute_ptr_offset_of_groups, + const ComputePtrOffset compute_ptr_offset_of_n) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + + const long_index_t a_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t e_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id_x >= gemm_desc_kernel_args[group_id].BlockStart_ && + block_id_x < gemm_desc_kernel_args[group_id].BlockEnd_)) && + left <= right) + { + if(block_id_x < gemm_desc_kernel_args[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset, + gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset, + Tuple<>{}, + gemm_desc_kernel_args[group_id].e_ptr_ + e_group_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + c_element_op, + gemm_desc_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_desc_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + Tuple<>{}, + gemm_desc_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_desc_kernel_args[group_id].block_2_etile_map_); +#else + ignore = gemm_desc_kernel_args; + ignore = gemms_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = compute_ptr_offset_of_groups; + ignore = compute_ptr_offset_of_n; +#endif +} + +} // namespace + +template +using is_tuple = decltype(std::declval().IsTuple()); + +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + typename BComputeDataType = AComputeDataType, + LoopScheduler LoopSched = make_default_loop_scheduler()> +struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t MaxGemmsNum = 32; + static_assert(NumDTensor == 0, "MultiD not supported."); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ConvToGemmFwdTransformerIndexT = TransformConvFwdToGemm; + + using ConvToGemmFwdTransformerLongIndexT = TransformConvFwdToGemm; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + static auto + MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + return in_gemmm_gemmk_desc; + } + + template + static auto + MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + return wei_gemmn_gemmk_desc; + } + + template + static auto + MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // desc for problem definition + constexpr static ConvToGemmFwdTransformerIndexT dummy_conv_to_gemm_transformer; + using AGridDesc_M_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using BGridDesc_N_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + + static auto + GenerateConvToGemmTransforms(ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformer_base, + const ADataType* a_grid_ptr_base, + EDataType* c_grid_ptr_base) + { + // Max number of splits + // We need to use it to avoid infinity loop + constexpr index_t max_split_numbers = MaxGemmsNum / 2; + // Arrays to store transformers with smaller descs than 2GB + Array conv_to_gemm_transformers_arr; + Array a_grid_ptrs_arr; + Array c_grid_ptrs_arr; + // Queue for spliting + std::queue conv_to_gemm_transformers_queue( + {conv_to_gemm_transformer_base}); + std::queue a_grid_ptrs_queue({a_grid_ptr_base}); + std::queue c_grid_ptrs_queue({c_grid_ptr_base}); + + index_t gemms_number = 0; + index_t split_numbers = 0; + // Algorithm: + // While queue is not empty: + // 1. Get transformer from queue. + // 2. If descs are smaller than 2GB push to result array. + // 3. If descs are bigger than 2GB split into left and right transformer. + // and push the both into the queue. + while(!conv_to_gemm_transformers_queue.empty() && split_numbers < max_split_numbers && + gemms_number < MaxGemmsNum) + { + // Get transformer from the queue + const auto& conv_to_gemm_transformer = conv_to_gemm_transformers_queue.front(); + const ADataType* a_grid_ptr = a_grid_ptrs_queue.front(); + EDataType* c_grid_ptr = c_grid_ptrs_queue.front(); + + // Check if convolution not exceed 2GB + if(conv_to_gemm_transformer.AreDescriptorsSmallerThan2GB()) + { + // If yes, push into result array + conv_to_gemm_transformers_arr(gemms_number) = + ConvToGemmFwdTransformerIndexT{conv_to_gemm_transformer}; + a_grid_ptrs_arr(gemms_number) = a_grid_ptr; + c_grid_ptrs_arr(gemms_number) = c_grid_ptr; + gemms_number++; + } + else + { + // If no, split into left and right convolutions + ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformers_left_part, + conv_to_gemm_transformers_right_part; + const ADataType* a_grid_right_ptr; + EDataType* c_grid_right_ptr; + + ck::tie(conv_to_gemm_transformers_left_part, + conv_to_gemm_transformers_right_part, + a_grid_right_ptr, + c_grid_right_ptr) = + conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, c_grid_ptr); + + conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_left_part); + conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_right_part); + // Left offsets remain the same + a_grid_ptrs_queue.push(a_grid_ptr); + a_grid_ptrs_queue.push(a_grid_right_ptr); + c_grid_ptrs_queue.push(c_grid_ptr); + c_grid_ptrs_queue.push(c_grid_right_ptr); + split_numbers++; + } + // Remove from the queue + conv_to_gemm_transformers_queue.pop(); + a_grid_ptrs_queue.pop(); + c_grid_ptrs_queue.pop(); + } + + const bool is_split_valid = conv_to_gemm_transformers_queue.empty(); + + return ck::make_tuple(conv_to_gemm_transformers_arr, + a_grid_ptrs_arr, + c_grid_ptrs_arr, + gemms_number, + is_split_valid); + } + +#define GridwiseGemmTemplateParameters \ + ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ + AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ + KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ + AComputeDataType + // Use appropriate gridwise gemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + // Structure for each gemm(conv) + struct GemmArgs + { + // pointers + const ADataType* a_ptr_; + const BDataType* b_ptr_; + EDataType* e_ptr_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + ck::index_t BlockStart_, BlockEnd_; + }; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, + const void* p_b, + const std::array& /*p_ds*/, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + /*ds_g_n_k_wos_lengths*/, + const std::array, NumDTensor>& + /*ds_g_n_k_wos_strides*/, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : num_group_{static_cast(a_g_n_c_wis_lengths[0])}, + compute_ptr_offset_of_groups_{}, + compute_ptr_offset_of_n_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // Perform grouped gemm, generate array of tranformer for convolution + Array conv_to_gemm_transformer_arr; + Array a_grid_ptrs; + Array c_grid_ptrs; + + ck::tie(conv_to_gemm_transformer_arr, + a_grid_ptrs, + c_grid_ptrs, + gemms_count_, + is_split_valid_) = + GenerateConvToGemmTransforms( + ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + e_g_n_k_wos_strides_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}, + static_cast(p_a), + static_cast(p_e)); + + grid_size_ = 0; + valid_gemms_count_ = 0; + + if(is_split_valid_) + { + // Create GemmArg for each gemm(conv) + for(index_t i = 0; i < gemms_count_; i++) + { + const AGridDesc_M_K a_grid_desc_m_k{DeviceOp::MakeAGridDescriptor_M_K( + conv_to_gemm_transformer_arr[i])}; + const BGridDesc_N_K b_grid_desc_n_k{DeviceOp::MakeBGridDescriptor_N_K( + conv_to_gemm_transformer_arr[i])}; + const auto e_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_arr[i]); + + const auto block_2_etile_map = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + + const index_t grid_size_grp = + block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + Tuple<>{}, + e_grid_desc_m_n, + block_2_etile_map)) + { + + gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{ + a_grid_ptrs[i], + static_cast(p_b), + c_grid_ptrs[i], + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k), + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k), + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n), + block_2_etile_map, + BlockStart, + BlockEnd}; + + valid_gemms_count_++; + } + } + // N is the same for all convs + conv_N_per_block_ = static_cast(conv_to_gemm_transformer_arr[I0].N_); + } + + // Strides for G and N remain the same + compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; + } + + void Print() const + { + for(index_t i = 0; i < valid_gemms_count_; i++) + { + std::cout << "A[AK0, M, AK1]: " << gemm_desc_kernel_args_[i].a_grid_desc_ak0_m_ak1_ + << std::endl; + std::cout << "B[BK0, N, BK1]: " << gemm_desc_kernel_args_[i].b_grid_desc_bk0_n_bk1_ + << std::endl; + std::cout + << "E[MBlock, MPerBlock, NBlock, NPerBlock]: " + << gemm_desc_kernel_args_[i].e_grid_desc_mblock_mperblock_nblock_nperblock_ + << std::endl; + } + } + + index_t num_group_; + index_t conv_N_per_block_; + + Array gemm_desc_kernel_args_; + + index_t grid_size_; + index_t gemms_count_; + index_t valid_gemms_count_; + + bool is_split_valid_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + + const index_t gdx = arg.grid_size_; + const index_t gdy = arg.num_group_; + const index_t gdz = num_workgroups_per_Conv_N; + + // K is constant for all gemms + const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) * + arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + const auto kernel = kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle< + GridwiseGemm, + MaxGemmsNum, + GemmArgs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.gemm_desc_kernel_args_, + arg.gemms_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + const long_index_t K = arg.b_g_k_c_xs_lengths_[I1]; + const long_index_t C = arg.b_g_k_c_xs_lengths_[I2]; + + // Check if all descs are valid + if(!(arg.is_split_valid_ && arg.gemms_count_ == arg.valid_gemms_count_)) + { + return false; + } + // check device + if(get_device_name() == "gfx908") + { + // FIXME: re-enable fp64 when SWDEV-335738 is fixed + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + if(!ck::is_xdl_supported()) + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3) + { + if(C != 1) + { + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3]; + + if(filter_spatial_dim != I3) + { + return false; + } + } + if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC()) + { + return false; + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + // Check access per C + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i64; + std::array a_g_n_c_wis_strides_i64; + std::array b_g_k_c_xs_lengths_i64; + std::array b_g_k_c_xs_strides_i64; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i64; + std::array, NumDTensor> ds_g_n_k_wos_strides_i64; + std::array e_g_n_k_wos_lengths_i64; + std::array e_g_n_k_wos_strides_i64; + std::array conv_filter_strides_i64; + std::array conv_filter_dilations_i64; + std::array input_left_pads_i64; + std::array input_right_pads_i64; + + array_convert(a_g_n_c_wis_lengths_i64, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i64, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i64, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i64, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i64[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i64[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i64, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i64, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i64, conv_filter_strides); + array_convert(conv_filter_dilations_i64, conv_filter_dilations); + array_convert(input_left_pads_i64, input_left_pads); + array_convert(input_right_pads_i64, input_right_pads); + + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i64, + a_g_n_c_wis_strides_i64, + b_g_k_c_xs_lengths_i64, + b_g_k_c_xs_strides_i64, + ds_g_n_k_wos_lengths_i64, + ds_g_n_k_wos_strides_i64, + e_g_n_k_wos_lengths_i64, + e_g_n_k_wos_strides_i64, + conv_filter_strides_i64, + conv_filter_dilations_i64, + input_left_pads_i64, + input_right_pads_i64, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto + MakeArgument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + + std::array a_g_n_c_wis_lengths_i64; + std::array a_g_n_c_wis_strides_i64; + std::array b_g_k_c_xs_lengths_i64; + std::array b_g_k_c_xs_strides_i64; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i64; + std::array, NumDTensor> ds_g_n_k_wos_strides_i64; + std::array e_g_n_k_wos_lengths_i64; + std::array e_g_n_k_wos_strides_i64; + std::array conv_filter_strides_i64; + std::array conv_filter_dilations_i64; + std::array input_left_pads_i64; + std::array input_right_pads_i64; + + array_convert(a_g_n_c_wis_lengths_i64, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i64, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i64, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i64, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i64[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i64[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i64, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i64, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i64, conv_filter_strides); + array_convert(conv_filter_dilations_i64, conv_filter_dilations); + array_convert(input_left_pads_i64, input_left_pads); + array_convert(input_right_pads_i64, input_right_pads); + + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i64, + a_g_n_c_wis_strides_i64, + b_g_k_c_xs_lengths_i64, + b_g_k_c_xs_strides_i64, + ds_g_n_k_wos_lengths_i64, + ds_g_n_k_wos_strides_i64, + e_g_n_k_wos_lengths_i64, + e_g_n_k_wos_strides_i64, + conv_filter_strides_i64, + conv_filter_dilations_i64, + input_left_pads_i64, + input_right_pads_i64, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CDEBlockTransferScalarPerVector_NPerBlock << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp index 4828beb3a2..648736fcbf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp @@ -57,7 +57,7 @@ struct DeviceImageToColumnImpl static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; - using GemmToConvFwdTransformer = + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = @@ -97,7 +97,7 @@ struct DeviceImageToColumnImpl b_g_k_c_xs_lengths[I2] = C; c_g_n_k_wos_lengths[I1] = N; - GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths, + ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths, image_g_n_c_wis_strides, b_g_k_c_xs_lengths, {}, // not needed for A Descriptor diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp index 27f48a84ba..5f6f2768eb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -111,6 +111,15 @@ struct GridwiseGemmDlMultipleD_km_kn_mn const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const CGridDesc_M_N& c_grid_desc_m_n) { + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB && + b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB && + c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB)) + { + return false; + } + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 1da7236978..562b9b8ffa 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -649,6 +649,15 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3 const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1, const CGridDesc_M_N& c_grid_desc_m_n) { + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_b_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB && + b_grid_desc_b_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB && + c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB)) + { + return false; + } + const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2); const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2); const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1); diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index 07cb7a7310..b91b12ad52 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -19,7 +19,8 @@ template + index_t NumGroupsToMerge = 1, + typename IndexType = index_t> struct TransformConvFwdToGemm { private: @@ -46,10 +47,10 @@ struct TransformConvFwdToGemm } template - static index_t GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths, - const ConvDimsType& a_g_n_c_wis_strides, - const ConvDimsType& c_g_n_k_wos_lengths, - const ConvDimsType& c_g_n_k_wos_strides) + static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides) { const long_index_t a_element_space_size = calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); @@ -59,7 +60,7 @@ struct TransformConvFwdToGemm c_element_space_size * sizeof(CDataType)); constexpr long_index_t TwoGB = (long_index_t{1} << 31); - const index_t N = a_g_n_c_wis_lengths[I1]; + const IndexType N = a_g_n_c_wis_lengths[I1]; if(element_space_size > TwoGB) { @@ -70,7 +71,7 @@ struct TransformConvFwdToGemm { // Find least divisor of N larger than element_space_size / TwoGB // Iterate up to sqrt(N). There are no divisors above this value. - for(index_t least_divisor = divisor; least_divisor * least_divisor <= N; + for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N; least_divisor++) { if(N % least_divisor == 0) @@ -98,6 +99,53 @@ struct TransformConvFwdToGemm public: __host__ __device__ constexpr TransformConvFwdToGemm() {} + template + __host__ __device__ + TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base) + : N_{static_cast(transform_conv_fwd_to_gemm_base.N_)}, + Di_{static_cast(transform_conv_fwd_to_gemm_base.Di_)}, + Hi_{static_cast(transform_conv_fwd_to_gemm_base.Hi_)}, + Wi_{static_cast(transform_conv_fwd_to_gemm_base.Wi_)}, + Do_{static_cast(transform_conv_fwd_to_gemm_base.Do_)}, + Ho_{static_cast(transform_conv_fwd_to_gemm_base.Ho_)}, + Wo_{static_cast(transform_conv_fwd_to_gemm_base.Wo_)}, + Z_{static_cast(transform_conv_fwd_to_gemm_base.Z_)}, + Y_{static_cast(transform_conv_fwd_to_gemm_base.Y_)}, + X_{static_cast(transform_conv_fwd_to_gemm_base.X_)}, + K_{static_cast(transform_conv_fwd_to_gemm_base.K_)}, + C_{static_cast(transform_conv_fwd_to_gemm_base.C_)}, + DiStride_{static_cast(transform_conv_fwd_to_gemm_base.DiStride_)}, + HiStride_{static_cast(transform_conv_fwd_to_gemm_base.HiStride_)}, + WiStride_{static_cast(transform_conv_fwd_to_gemm_base.WiStride_)}, + DoStride_{static_cast(transform_conv_fwd_to_gemm_base.DoStride_)}, + HoStride_{static_cast(transform_conv_fwd_to_gemm_base.HoStride_)}, + WoStride_{static_cast(transform_conv_fwd_to_gemm_base.WoStride_)}, + XStride_{static_cast(transform_conv_fwd_to_gemm_base.XStride_)}, + CStrideTensorA_{static_cast(transform_conv_fwd_to_gemm_base.CStrideTensorA_)}, + CStrideTensorB_{static_cast(transform_conv_fwd_to_gemm_base.CStrideTensorB_)}, + KStrideTensorB_{static_cast(transform_conv_fwd_to_gemm_base.KStrideTensorB_)}, + KStrideTensorC_{static_cast(transform_conv_fwd_to_gemm_base.KStrideTensorC_)}, + NStrideTensorA_{static_cast(transform_conv_fwd_to_gemm_base.NStrideTensorA_)}, + NStrideTensorC_{static_cast(transform_conv_fwd_to_gemm_base.NStrideTensorC_)}, + GStrideTensorA_{static_cast(transform_conv_fwd_to_gemm_base.GStrideTensorA_)}, + GStrideTensorB_{static_cast(transform_conv_fwd_to_gemm_base.GStrideTensorB_)}, + GStrideTensorC_{static_cast(transform_conv_fwd_to_gemm_base.GStrideTensorC_)}, + ConvStrideD_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideD_)}, + ConvStrideH_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideH_)}, + ConvStrideW_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideW_)}, + ConvDilationD_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationD_)}, + ConvDilationH_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationH_)}, + ConvDilationW_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationW_)}, + InLeftPadD_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadD_)}, + InLeftPadH_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadH_)}, + InLeftPadW_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadW_)}, + InRightPadD_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadD_)}, + InRightPadH_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadH_)}, + InRightPadW_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadW_)}, + ZYX_{static_cast(transform_conv_fwd_to_gemm_base.ZYX_)} + { + } + template > || - is_same_v>); - static_assert(is_same_v> || - is_same_v>); + static_assert(is_same_v> || + is_same_v>); + static_assert(is_same_v> || + is_same_v>); if constexpr(SplitN) { @@ -164,7 +215,6 @@ struct TransformConvFwdToGemm { N_ = c_g_n_k_wos_lengths[I1]; } - NDoHoWo_ = N_ * Wo_; } template > || - is_same_v>); - static_assert(is_same_v> || - is_same_v>); + static_assert(is_same_v> || + is_same_v>); + static_assert(is_same_v> || + is_same_v>); if constexpr(SplitN) { @@ -233,7 +286,6 @@ struct TransformConvFwdToGemm { N_ = c_g_n_k_wos_lengths[I1]; } - NDoHoWo_ = N_ * Ho_ * Wo_; } template > || - is_same_v>); - static_assert(is_same_v> || - is_same_v>); + static_assert(is_same_v> || + is_same_v>); + static_assert(is_same_v> || + is_same_v>); if constexpr(SplitN) { @@ -302,7 +357,122 @@ struct TransformConvFwdToGemm { N_ = c_g_n_k_wos_lengths[I1]; } - NDoHoWo_ = N_ * Do_ * Ho_ * Wo_; + } + + __host__ bool AreDescriptorsSmallerThan2GB() const + { + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const long_index_t in_desc_space_size = + I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ + + (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_; + const long_index_t out_desc_space_size = + I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ + + (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_; + + bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB; + bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB; + + return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB; + } + + __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base, + CDataType* c_grid_ptr_base) const + { + // Create copies + auto conv_to_gemm_transformer_left = *this; + auto conv_to_gemm_transformer_right = *this; + IndexType a_right_offset = 0; + IndexType c_right_offset = 0; + // Calculate real filter size + const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1; + const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1; + const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1; + // Calculate start position in input for right tensor + const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_; + const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_; + const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_; + // Calculate last position in input for left tensor + const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff; + const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff; + const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff; + // Allow to split if whole left padding will be in left tensor and right padding in right + // tensor + const bool is_possible_to_split_d = Do_ != 1 && + di_right_transformer_start_idx > InLeftPadD_ && + di_left_transformer_end_idx <= (InLeftPadD_ + Di_); + const bool is_possible_to_split_h = Ho_ != 1 && + hi_right_transformer_start_idx > InLeftPadH_ && + hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_); + const bool is_possible_to_split_w = Wo_ != 1 && + wi_right_transformer_start_idx > InLeftPadW_ && + wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_); + + if(is_possible_to_split_d) + { + // Apply new sizes + // Split output on half + conv_to_gemm_transformer_left.Do_ = Do_ / 2; + conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_; + conv_to_gemm_transformer_right.InLeftPadD_ = 0; + // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadD_ = 0; + conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_; + // Calculate new input size + conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_; + conv_to_gemm_transformer_right.Di_ = + math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_), + (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff); + ; + // Calcualte offsets + a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_; + c_right_offset = (Do_ / 2) * DoStride_; + } + else if(is_possible_to_split_h) + { + conv_to_gemm_transformer_left.Ho_ = Ho_ / 2; + conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2; + + conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_; + conv_to_gemm_transformer_right.InLeftPadH_ = 0; + + conv_to_gemm_transformer_left.InRightPadH_ = 0; + conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_; + + conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_; + conv_to_gemm_transformer_right.Hi_ = + math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_), + (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff); + a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_; + c_right_offset = (Ho_ / 2) * HoStride_; + } + else if(is_possible_to_split_w) + { + conv_to_gemm_transformer_left.Wo_ = Wo_ / 2; + conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2; + + conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_; + conv_to_gemm_transformer_right.InLeftPadW_ = 0; + + conv_to_gemm_transformer_left.InRightPadW_ = 0; + conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_; + + conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_; + conv_to_gemm_transformer_right.Wi_ = + math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_), + (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff); + + a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_; + c_right_offset = (Wo_ / 2) * WoStride_; + } + // Return left transform, right transformer, right offset to Input and right offset to + // Output + return ck::make_tuple(conv_to_gemm_transformer_left, + conv_to_gemm_transformer_right, + a_grid_ptr_base + a_right_offset, + c_grid_ptr_base + c_right_offset); } // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as @@ -320,20 +490,27 @@ struct TransformConvFwdToGemm { if constexpr(NumGroupsToMerge == 1) { - return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_), - make_tuple(WiStride_, CStrideTensorA_)); + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wo_, C_), + make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_)); + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); } else { const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( - make_tuple(NDoHoWo_, NumGroupsToMerge, C_), - make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_)); + make_tuple(N_, Wo_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_)); return transform_tensor_descriptor( in_gemmm_groups_gemmk_desc, - make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)), + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), make_pass_through_transform(C_)), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } } @@ -527,20 +704,29 @@ struct TransformConvFwdToGemm { if constexpr(NumGroupsToMerge == 1) { - return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_), - make_tuple(WiStride_, CStrideTensorA_)); + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Ho_, Wo_, C_), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); } else { const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( - make_tuple(NDoHoWo_, NumGroupsToMerge, C_), - make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_)); + make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_), + make_tuple( + NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); return transform_tensor_descriptor( in_gemmm_groups_gemmk_desc, - make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)), + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), make_pass_through_transform(C_)), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } } @@ -759,20 +945,34 @@ struct TransformConvFwdToGemm { if constexpr(NumGroupsToMerge == 1) { - return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_), - make_tuple(WiStride_, CStrideTensorA_)); + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, C_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); } else { const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( - make_tuple(NDoHoWo_, NumGroupsToMerge, C_), - make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_)); + make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorA_, + CStrideTensorA_)); return transform_tensor_descriptor( in_gemmm_groups_gemmk_desc, - make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)), - make_pass_through_transform(C_)), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } } @@ -1119,45 +1319,70 @@ struct TransformConvFwdToGemm } template || - is_same_v || - is_same_v, + index_t NDimSp = NDimSpatial, + + typename std::enable_if), bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { - return make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo_, K_)); + return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_), + make_tuple(I0, KStrideTensorC_)); } - template < - typename CLayout, + template || - is_same_v || - is_same_v || - is_same_v || - is_same_v || - is_same_v, - bool>::type = false> + typename std::enable_if), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { + return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), + make_tuple(I0, KStrideTensorC_)); + } + + template ), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), + make_tuple(I0, KStrideTensorC_)); + } + + template || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + const IndexType NDoHoWo = N_ * Wo_; if constexpr(NumGroupsToMerge == 1) { - return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_), + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_), make_tuple(WoStride_, KStrideTensorC_)); } else { const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor( - make_tuple(NDoHoWo_, NumGroupsToMerge, K_, 1), - make_tuple(WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_)); + make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple( + NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_)); // Padd 1 to NumGroupsToMerge const auto padded_desc = transform_tensor_descriptor( nhwo_groups_k_1_desc, - make_tuple(make_pass_through_transform(NDoHoWo_), + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(K_), make_pad_transform(1, 0, NumGroupsToMerge - 1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); // We need only matrices from diagonal. X_or returns 0 for the same // values. So if matrices is not on diagonal then it will be stored in padding. @@ -1167,7 +1392,7 @@ struct TransformConvFwdToGemm NumGroupsToMerge == 32 || NumGroupsToMerge == 64); const auto unmerged_padded_desc = transform_tensor_descriptor( padded_desc, - make_tuple(make_pass_through_transform(NDoHoWo_), + make_tuple(make_pass_through_transform(NDoHoWo), make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), make_pass_through_transform(K_)), make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), @@ -1175,45 +1400,146 @@ struct TransformConvFwdToGemm // Merge To M, N return transform_tensor_descriptor( unmerged_padded_desc, - make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)), + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), make_merge_transform(make_tuple(K_, NumGroupsToMerge))), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } } - // for output bias template , - bool>::type = false> + index_t NDimSp = NDimSpatial, + + typename std::enable_if< + NDimSp == 2 && (is_same_v || + is_same_v || + is_same_v), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { - const auto out_gemmm_gemmn_desc = - make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_), make_tuple(I0, KStrideTensorC_)); - - return out_gemmm_gemmn_desc; + const IndexType NDoHoWo = N_ * Ho_ * Wo_; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_), + make_tuple(WoStride_, KStrideTensorC_)); + } + else + { + const auto nhwo_groups_k_1_desc = + make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple(NStrideTensorC_, + HoStride_, + WoStride_, + GStrideTensorC_, + KStrideTensorC_, + GStrideTensorC_)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + // We need only matrices from diagonal. X_or returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } - public: - index_t N_; + template || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { - private: - const index_t Di_, Hi_, Wi_; - const index_t Do_, Ho_, Wo_; - const index_t Z_, Y_, X_; - const index_t K_, C_; - const index_t DiStride_, HiStride_, WiStride_; - const index_t WoStride_; - const index_t XStride_; - const index_t CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_; - const index_t NStrideTensorA_; - const index_t GStrideTensorA_, GStrideTensorB_, GStrideTensorC_; - const index_t ConvStrideD_, ConvStrideH_, ConvStrideW_; - const index_t ConvDilationD_, ConvDilationH_, ConvDilationW_; - const index_t InLeftPadD_, InLeftPadH_, InLeftPadW_; - const index_t InRightPadD_, InRightPadH_, InRightPadW_; - const index_t ZYX_; - index_t NDoHoWo_; + const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_), + make_tuple(WoStride_, KStrideTensorC_)); + } + else + { + const auto nhwo_groups_k_1_desc = + make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple(NStrideTensorC_, + DoStride_, + HoStride_, + WoStride_, + GStrideTensorC_, + KStrideTensorC_, + GStrideTensorC_)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}, Sequence<5>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + // We need only matrices from diagonal. X_or returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + IndexType N_; + IndexType Di_, Hi_, Wi_; + IndexType Do_, Ho_, Wo_; + IndexType Z_, Y_, X_; + IndexType K_, C_; + IndexType DiStride_, HiStride_, WiStride_; + IndexType DoStride_, HoStride_, WoStride_; + IndexType XStride_; + IndexType CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_; + IndexType NStrideTensorA_, NStrideTensorC_; + IndexType GStrideTensorA_, GStrideTensorB_, GStrideTensorC_; + IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_; + IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_; + IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_; + IndexType InRightPadD_, InRightPadH_, InRightPadW_; + IndexType ZYX_; }; // wrapper class to call member functions on TransformConvToGemm struct at runtime @@ -1230,17 +1556,17 @@ struct TransformConv if(NDimSpatial == 2) { return conv_fwd_to_gemm - .template MakeCDescriptor_M_N(); + .template MakeCDescriptor_M_N(); } else if(NDimSpatial == 3) { return conv_fwd_to_gemm - .template MakeCDescriptor_M_N(); + .template MakeCDescriptor_M_N(); } else if(NDimSpatial == 1) { return conv_fwd_to_gemm - .template MakeCDescriptor_M_N(); + .template MakeCDescriptor_M_N(); } } }; diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 382b9c5551..87fa9aa38a 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/data_type.hpp" #include "ck/utility/f8_utils.hpp" #include "ck/utility/random_gen.hpp" +#include "ck/utility/array.hpp" namespace ck { // Define the common macro for gfx94x models @@ -500,6 +501,25 @@ inline __host__ __device__ half_t type_convert(bf8_t x) #endif } +template +inline __host__ __device__ void array_convert(std::array& y, + const std::array& x) +{ + for(std::size_t i = 0; i < NumElems; i++) + { + y[i] = type_convert(x[i]); + } +} + +template +inline __host__ __device__ void array_convert(Array& y, const Array& x) +{ + for(std::size_t i = 0; i < NumElems; i++) + { + y[i] = type_convert(x[i]); + } +} + // Declare a template function for bf16 conversion using RTN template __host__ __device__ constexpr Y bf16_convert_rtn(X x); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp index 5f2ab12164..51379b0944 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -39,11 +39,11 @@ struct ReferenceColumnToImage : public device::BaseOperator public: Argument(const Tensor& input, Tensor& output, - std::vector filter_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads) + std::vector filter_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) : input_{input}, output_{output}, conv_strides_{conv_filter_strides}, @@ -58,24 +58,25 @@ struct ReferenceColumnToImage : public device::BaseOperator const Tensor& input_; Tensor& output_; - std::vector conv_strides_; - std::vector conv_dilations_; - std::vector in_left_pads_; - std::vector in_right_pads_; + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; - std::vector filter_spatial_lengths_; - std::vector output_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; private: void initOutputSpatialLengths() { constexpr auto input_offset_to_spatial = 3; - for(ck::index_t i = 0; i < NDimSpatial; ++i) + for(ck::long_index_t i = 0; i < NDimSpatial; ++i) { // XEff = (X - 1) * conv_dilation_w + 1; // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1; + const ck::long_index_t x_eff = + (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1; output_spatial_lengths_.push_back( (output_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] + @@ -98,26 +99,26 @@ struct ReferenceColumnToImage : public device::BaseOperator throw std::runtime_error("wrong! inconsistent dimension"); } - const index_t G = arg.output_.GetLengths()[0]; - const index_t N = arg.output_.GetLengths()[1]; - const index_t C = arg.output_.GetLengths()[2]; + const long_index_t G = arg.output_.GetLengths()[0]; + const long_index_t N = arg.output_.GetLengths()[1]; + const long_index_t C = arg.output_.GetLengths()[2]; if constexpr(NDimSpatial == 1) { - const index_t Wo = arg.output_spatial_lengths_[0]; - auto func = [&](auto g, auto n) { - for(index_t wo = 0; wo < Wo; ++wo) + const long_index_t Wo = arg.output_spatial_lengths_[0]; + auto func = [&](auto g, auto n) { + for(long_index_t wo = 0; wo < Wo; ++wo) { - index_t row = n * Wo + wo; - index_t column = 0; + long_index_t row = n * Wo + wo; + long_index_t column = 0; - for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x) + for(long_index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x) { auto wi = static_cast(wo * arg.conv_strides_[0]) + static_cast(x * arg.conv_dilations_[0]) - static_cast(arg.in_left_pads_[0]); - for(index_t c = 0; c < C; ++c) + for(long_index_t c = 0; c < C; ++c) { if(wi >= 0 && ck::type_convert(wi) < arg.output_.GetLengths()[3]) @@ -140,32 +141,32 @@ struct ReferenceColumnToImage : public device::BaseOperator } else if constexpr(NDimSpatial == 2) { - const index_t Ho = arg.output_spatial_lengths_[0]; - const index_t Wo = arg.output_spatial_lengths_[1]; + const long_index_t Ho = arg.output_spatial_lengths_[0]; + const long_index_t Wo = arg.output_spatial_lengths_[1]; auto func = [&](auto g, auto n) { - for(index_t ho = 0; ho < Ho; ++ho) + for(long_index_t ho = 0; ho < Ho; ++ho) { - for(index_t wo = 0; wo < Wo; ++wo) + for(long_index_t wo = 0; wo < Wo; ++wo) { - index_t row = n * Ho * Wo + ho * Wo + wo; - index_t column = 0; + long_index_t row = n * Ho * Wo + ho * Wo + wo; + long_index_t column = 0; - for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y) + for(long_index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y) { auto hi = static_cast(ho * arg.conv_strides_[0]) + static_cast(y * arg.conv_dilations_[0]) - static_cast(arg.in_left_pads_[0]); - for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x) + for(long_index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x) { auto wi = static_cast(wo * arg.conv_strides_[1]) + static_cast(x * arg.conv_dilations_[1]) - static_cast(arg.in_left_pads_[1]); - for(index_t c = 0; c < C; ++c) + for(long_index_t c = 0; c < C; ++c) { if(hi >= 0 && @@ -196,27 +197,27 @@ struct ReferenceColumnToImage : public device::BaseOperator } else if constexpr(NDimSpatial == 3) { - const index_t Do = arg.output_spatial_lengths_[0]; - const index_t Ho = arg.output_spatial_lengths_[1]; - const index_t Wo = arg.output_spatial_lengths_[2]; + const long_index_t Do = arg.output_spatial_lengths_[0]; + const long_index_t Ho = arg.output_spatial_lengths_[1]; + const long_index_t Wo = arg.output_spatial_lengths_[2]; auto func = [&](auto g, auto n) { - for(index_t d_o = 0; d_o < Do; ++d_o) + for(long_index_t d_o = 0; d_o < Do; ++d_o) { - for(index_t ho = 0; ho < Ho; ++ho) + for(long_index_t ho = 0; ho < Ho; ++ho) { - for(index_t wo = 0; wo < Wo; ++wo) + for(long_index_t wo = 0; wo < Wo; ++wo) { - index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo; - index_t column = 0; + long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo; + long_index_t column = 0; - for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z) + for(long_index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z) { auto di = static_cast(d_o * arg.conv_strides_[0]) + static_cast(z * arg.conv_dilations_[0]) - static_cast(arg.in_left_pads_[0]); - for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y) + for(long_index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y) { auto hi = static_cast(ho * @@ -224,7 +225,8 @@ struct ReferenceColumnToImage : public device::BaseOperator static_cast(y * arg.conv_dilations_[1]) - static_cast(arg.in_left_pads_[1]); - for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x) + for(long_index_t x = 0; x < arg.filter_spatial_lengths_[2]; + ++x) { auto wi = static_cast( @@ -232,7 +234,7 @@ struct ReferenceColumnToImage : public device::BaseOperator static_cast( x * arg.conv_dilations_[2]) - static_cast(arg.in_left_pads_[2]); - for(index_t c = 0; c < C; ++c) + for(long_index_t c = 0; c < C; ++c) { if(di >= 0 && ck::type_convert(di) < @@ -294,15 +296,15 @@ struct ReferenceColumnToImage : public device::BaseOperator bool IsSupportedArgument(const Argument& arg) { - const ck::index_t G = arg.output_.GetLengths()[0]; - const ck::index_t N = arg.output_.GetLengths()[1]; - const ck::index_t C = arg.output_.GetLengths()[2]; + const ck::long_index_t G = arg.output_.GetLengths()[0]; + const ck::long_index_t N = arg.output_.GetLengths()[1]; + const ck::long_index_t C = arg.output_.GetLengths()[2]; - const index_t NDoHoWo = - N * ck::accumulate_n( + const long_index_t NDoHoWo = + N * ck::accumulate_n( arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); - const index_t CZYX = - C * ck::accumulate_n( + const long_index_t CZYX = + C * ck::accumulate_n( arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); if(!(arg.input_.GetLengths()[0] == static_cast(G) && @@ -326,11 +328,11 @@ struct ReferenceColumnToImage : public device::BaseOperator static auto MakeArgument(const Tensor& input, Tensor& output, - std::vector filter_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads) + std::vector filter_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) { return Argument{input, output, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp index a41f952408..10b169c21e 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp @@ -38,10 +38,10 @@ struct ReferenceConvBwdData : public device::BaseOperator Tensor& input, const Tensor& weight, const Tensor& output, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, @@ -72,10 +72,10 @@ struct ReferenceConvBwdData : public device::BaseOperator const std::array, NumBElementwiseTensor>& elementwise_b_tensors_; const std::array, NumDElementwiseTensor>& elementwise_d_tensors_; - std::vector conv_strides_; - std::vector conv_dilations_; - std::vector in_left_pads_; - std::vector in_right_pads_; + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; @@ -447,10 +447,10 @@ struct ReferenceConvBwdData : public device::BaseOperator Tensor& input, const Tensor& weight, const Tensor& output, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp index a8f2ce1713..b6e889a325 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp @@ -40,10 +40,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator const Tensor& in_n_c_hi_wi, Tensor& wei_k_c_y_x, const Tensor& out_n_k_ho_wo, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, @@ -74,10 +74,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator const std::array, NumBElementwiseTensor>& elementwise_b_tensors_; const std::array, NumDElementwiseTensor>& elementwise_d_tensors_; - std::vector conv_strides_; - std::vector conv_dilations_; - std::vector in_left_pads_; - std::vector in_right_pads_; + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; @@ -402,10 +402,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator const Tensor& in_n_c_hi_wi, Tensor& wei_k_c_y_x, const Tensor& out_n_k_ho_wo, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp index d63b5256f9..9c1349f56c 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -69,10 +69,10 @@ struct ReferenceConvFwd : public device::BaseOperator const Tensor& input, const Tensor& weight, Tensor& output, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, @@ -103,10 +103,10 @@ struct ReferenceConvFwd : public device::BaseOperator const std::array, NumBElementwiseTensor>& elementwise_b_tensors_; const std::array, NumDElementwiseTensor>& elementwise_d_tensors_; - std::vector conv_strides_; - std::vector conv_dilations_; - std::vector in_left_pads_; - std::vector in_right_pads_; + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; @@ -416,10 +416,10 @@ struct ReferenceConvFwd : public device::BaseOperator const Tensor& input, const Tensor& weight, Tensor& output, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp index 4682c5c223..da16295da3 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp @@ -40,11 +40,11 @@ struct ReferenceImageToColumn : public device::BaseOperator public: Argument(const Tensor& input, Tensor& output, - std::vector filter_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads) + std::vector filter_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) : input_{input}, output_{output}, conv_strides_{conv_filter_strides}, @@ -59,13 +59,13 @@ struct ReferenceImageToColumn : public device::BaseOperator const Tensor& input_; Tensor& output_; - std::vector conv_strides_; - std::vector conv_dilations_; - std::vector in_left_pads_; - std::vector in_right_pads_; + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; - std::vector filter_spatial_lengths_; - std::vector output_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; private: void initOutputSpatialLengths() @@ -76,7 +76,8 @@ struct ReferenceImageToColumn : public device::BaseOperator { // XEff = (X - 1) * conv_dilation_w + 1; // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1; + const ck::long_index_t x_eff = + (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1; output_spatial_lengths_.push_back( (input_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] + @@ -99,24 +100,24 @@ struct ReferenceImageToColumn : public device::BaseOperator throw std::runtime_error("wrong! inconsistent dimension"); } - const index_t G = arg.input_.GetLengths()[0]; - const index_t N = arg.input_.GetLengths()[1]; - const index_t C = arg.input_.GetLengths()[2]; + const long_index_t G = arg.input_.GetLengths()[0]; + const long_index_t N = arg.input_.GetLengths()[1]; + const long_index_t C = arg.input_.GetLengths()[2]; if constexpr(NDimSpatial == 1) { - const index_t Wo = arg.output_spatial_lengths_[0]; - auto func = [&](auto g, auto n, auto wo) { - index_t row = n * Wo + wo; - index_t column = 0; + const long_index_t Wo = arg.output_spatial_lengths_[0]; + auto func = [&](auto g, auto n, auto wo) { + long_index_t row = n * Wo + wo; + long_index_t column = 0; - for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x) + for(long_index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x) { auto wi = static_cast(wo * arg.conv_strides_[0]) + static_cast(x * arg.conv_dilations_[0]) - static_cast(arg.in_left_pads_[0]); - for(index_t c = 0; c < C; ++c) + for(long_index_t c = 0; c < C; ++c) { if(wi >= 0 && ck::type_convert(wi) < arg.input_.GetLengths()[3]) @@ -135,26 +136,26 @@ struct ReferenceImageToColumn : public device::BaseOperator } else if constexpr(NDimSpatial == 2) { - const index_t Ho = arg.output_spatial_lengths_[0]; - const index_t Wo = arg.output_spatial_lengths_[1]; + const long_index_t Ho = arg.output_spatial_lengths_[0]; + const long_index_t Wo = arg.output_spatial_lengths_[1]; auto func = [&](auto g, auto n, auto ho, auto wo) { - index_t row = n * Ho * Wo + ho * Wo + wo; - index_t column = 0; + long_index_t row = n * Ho * Wo + ho * Wo + wo; + long_index_t column = 0; - for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y) + for(long_index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y) { auto hi = static_cast(ho * arg.conv_strides_[0]) + static_cast(y * arg.conv_dilations_[0]) - static_cast(arg.in_left_pads_[0]); - for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x) + for(long_index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x) { auto wi = static_cast(wo * arg.conv_strides_[1]) + static_cast(x * arg.conv_dilations_[1]) - static_cast(arg.in_left_pads_[1]); - for(index_t c = 0; c < C; ++c) + for(long_index_t c = 0; c < C; ++c) { if(hi >= 0 && @@ -178,31 +179,31 @@ struct ReferenceImageToColumn : public device::BaseOperator } else if constexpr(NDimSpatial == 3) { - const index_t Do = arg.output_spatial_lengths_[0]; - const index_t Ho = arg.output_spatial_lengths_[1]; - const index_t Wo = arg.output_spatial_lengths_[2]; + const long_index_t Do = arg.output_spatial_lengths_[0]; + const long_index_t Ho = arg.output_spatial_lengths_[1]; + const long_index_t Wo = arg.output_spatial_lengths_[2]; auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) { - index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo; - index_t column = 0; + long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo; + long_index_t column = 0; - for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z) + for(long_index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z) { auto di = static_cast(d_o * arg.conv_strides_[0]) + static_cast(z * arg.conv_dilations_[0]) - static_cast(arg.in_left_pads_[0]); - for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y) + for(long_index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y) { auto hi = static_cast(ho * arg.conv_strides_[1]) + static_cast(y * arg.conv_dilations_[1]) - static_cast(arg.in_left_pads_[1]); - for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x) + for(long_index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x) { auto wi = static_cast(wo * arg.conv_strides_[2]) + static_cast(x * arg.conv_dilations_[2]) - static_cast(arg.in_left_pads_[2]); - for(index_t c = 0; c < C; ++c) + for(long_index_t c = 0; c < C; ++c) { if(di >= 0 && ck::type_convert(di) < @@ -259,15 +260,15 @@ struct ReferenceImageToColumn : public device::BaseOperator bool IsSupportedArgument(const Argument& arg) { - const ck::index_t G = arg.input_.GetLengths()[0]; - const ck::index_t N = arg.input_.GetLengths()[1]; - const ck::index_t C = arg.input_.GetLengths()[2]; + const ck::long_index_t G = arg.input_.GetLengths()[0]; + const ck::long_index_t N = arg.input_.GetLengths()[1]; + const ck::long_index_t C = arg.input_.GetLengths()[2]; - const index_t NDoHoWo = - N * ck::accumulate_n( + const long_index_t NDoHoWo = + N * ck::accumulate_n( arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); - const index_t CZYX = - C * ck::accumulate_n( + const long_index_t CZYX = + C * ck::accumulate_n( arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); if(!(arg.output_.GetLengths()[0] == static_cast(G) && @@ -291,11 +292,11 @@ struct ReferenceImageToColumn : public device::BaseOperator static auto MakeArgument(const Tensor& input, Tensor& output, - std::vector filter_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads) + std::vector filter_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) { return Argument{input, output, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp new file mode 100644 index 0000000000..05cb8d5d05 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_conv_fwd_xdl_large_tensor_bf16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_large_tensor_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 0233d6d85c..4e117f86aa 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -17,6 +17,7 @@ #endif #ifdef CK_USE_XDL #include "grouped_convolution_forward_xdl.inc" +#include "grouped_convolution_forward_xdl_large_tensor.inc" #include "grouped_convolution_forward_xdl_merged_groups.inc" #include "grouped_convolution_forward_comp_xdl.inc" #include "grouped_convolution_forward_mem_inter_xdl.inc" @@ -200,6 +201,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs); @@ -215,6 +218,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances( op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(op_ptrs); @@ -232,6 +237,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(op_ptrs); @@ -291,6 +298,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs); @@ -347,6 +356,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances( op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(op_ptrs); @@ -364,6 +375,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc new file mode 100644 index 0000000000..6a2c61d058 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/utility/convolution_parameter.hpp b/library/include/ck/library/utility/convolution_parameter.hpp index df6efca108..70d581a67e 100644 --- a/library/include/ck/library/utility/convolution_parameter.hpp +++ b/library/include/ck/library/utility/convolution_parameter.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -31,23 +31,35 @@ struct ConvParam const std::vector& left_pads, const std::vector& right_pads); - ck::index_t num_dim_spatial_; - ck::index_t G_; - ck::index_t N_; - ck::index_t K_; - ck::index_t C_; + ConvParam(ck::long_index_t n_dim, + ck::long_index_t group_count, + ck::long_index_t n_batch, + ck::long_index_t n_out_channels, + ck::long_index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads); - std::vector filter_spatial_lengths_; - std::vector input_spatial_lengths_; - std::vector output_spatial_lengths_; + ck::long_index_t num_dim_spatial_; + ck::long_index_t G_; + ck::long_index_t N_; + ck::long_index_t K_; + ck::long_index_t C_; - std::vector conv_filter_strides_; - std::vector conv_filter_dilations_; + std::vector filter_spatial_lengths_; + std::vector input_spatial_lengths_; + std::vector output_spatial_lengths_; - std::vector input_left_pads_; - std::vector input_right_pads_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; - std::vector GetOutputSpatialLengths() const; + std::vector input_left_pads_; + std::vector input_right_pads_; + + std::vector GetOutputSpatialLengths() const; std::size_t GetFlops() const; diff --git a/library/include/ck/library/utility/host_tensor.hpp b/library/include/ck/library/utility/host_tensor.hpp index 493b992aca..a58acaf116 100644 --- a/library/include/ck/library/utility/host_tensor.hpp +++ b/library/include/ck/library/utility/host_tensor.hpp @@ -96,9 +96,16 @@ struct HostTensorDescriptor this->CalculateStrides(); } + HostTensorDescriptor(const std::initializer_list& lens) + : mLens(lens.begin(), lens.end()) + { + this->CalculateStrides(); + } + template , std::size_t>>> + std::is_convertible_v, std::size_t> || + std::is_convertible_v, ck::long_index_t>>> HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end()) { this->CalculateStrides(); @@ -114,11 +121,19 @@ struct HostTensorDescriptor { } + HostTensorDescriptor(const std::initializer_list& lens, + const std::initializer_list& strides) + : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) + { + } + template , std::size_t> && - std::is_convertible_v, std::size_t>>> + (std::is_convertible_v, std::size_t> && + std::is_convertible_v, std::size_t>) || + (std::is_convertible_v, ck::long_index_t> && + std::is_convertible_v, ck::long_index_t>)>> HostTensorDescriptor(const Lengths& lens, const Strides& strides) : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 170625a6a0..095e02795b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -9,6 +9,11 @@ add_instance_library(device_grouped_conv2d_fwd_instance xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + # large tensor + # NHWGC, GKYXC, NHWGK + xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp + xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp # merged groups # NHWGC, GKYXC, NHWGK xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..5eff1fd33b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000..9139622138 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp new file mode 100644 index 0000000000..1d52e21c8c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 5be6672723..cd80d77d09 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -9,6 +9,10 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp + xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..62198fff67 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_bf16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..3357653adf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp new file mode 100644 index 0000000000..d276b8e3bc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/utility/convolution_parameter.cpp b/library/src/utility/convolution_parameter.cpp index 57cedd6019..a71f8a4fa1 100644 --- a/library/src/utility/convolution_parameter.cpp +++ b/library/src/utility/convolution_parameter.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/host_utility/io.hpp" @@ -20,6 +20,63 @@ ConvParam::ConvParam(ck::index_t n_dim, const std::vector& dilations, const std::vector& left_pads, const std::vector& right_pads) + : num_dim_spatial_(static_cast(n_dim)), + G_(static_cast(group_count)), + N_(static_cast(n_batch)), + K_(static_cast(n_out_channels)), + C_(static_cast(n_in_channels)), + filter_spatial_lengths_(num_dim_spatial_), + input_spatial_lengths_(num_dim_spatial_), + output_spatial_lengths_(num_dim_spatial_), + conv_filter_strides_(num_dim_spatial_), + conv_filter_dilations_(num_dim_spatial_), + input_left_pads_(num_dim_spatial_), + input_right_pads_(num_dim_spatial_) +{ + if(static_cast(filter_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(input_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(conv_filter_strides_.size()) != num_dim_spatial_ || + static_cast(conv_filter_dilations_.size()) != num_dim_spatial_ || + static_cast(input_left_pads_.size()) != num_dim_spatial_ || + static_cast(input_right_pads_.size()) != num_dim_spatial_) + { + throw( + std::runtime_error("ConvParam::ConvParam: " + "parameter size is different from number of declared dimensions!")); + } + + for(ck::index_t i = 0; i < num_dim_spatial_; ++i) + { + filter_spatial_lengths_[i] = static_cast(filters_len[i]); + input_spatial_lengths_[i] = static_cast(input_len[i]); + conv_filter_strides_[i] = static_cast(strides[i]); + conv_filter_dilations_[i] = static_cast(dilations[i]); + input_left_pads_[i] = static_cast(left_pads[i]); + input_right_pads_[i] = static_cast(right_pads[i]); + + // XEff = (X - 1) * conv_dilation_w + 1; + // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck::long_index_t x_eff = + (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1; + + output_spatial_lengths_[i] = + (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) / + conv_filter_strides_[i] + + 1; + } +} + +ConvParam::ConvParam(ck::long_index_t n_dim, + ck::long_index_t group_count, + ck::long_index_t n_batch, + ck::long_index_t n_out_channels, + ck::long_index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads) : num_dim_spatial_(n_dim), G_(group_count), N_(n_batch), @@ -49,7 +106,8 @@ ConvParam::ConvParam(ck::index_t n_dim, { // XEff = (X - 1) * conv_dilation_w + 1; // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1; + const ck::long_index_t x_eff = + (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1; output_spatial_lengths_[i] = (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) / @@ -63,7 +121,7 @@ ConvParam::ConvParam() { } -std::vector ConvParam::GetOutputSpatialLengths() const +std::vector ConvParam::GetOutputSpatialLengths() const { return output_spatial_lengths_; } @@ -97,46 +155,46 @@ std::string get_conv_param_parser_helper_msg() ck::utils::conv::ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]) { - const ck::index_t G = std::stoi(argv[arg_idx++]); - const ck::index_t N = std::stoi(argv[arg_idx++]); - const ck::index_t K = std::stoi(argv[arg_idx++]); - const ck::index_t C = std::stoi(argv[arg_idx++]); + const ck::long_index_t G = std::stol(argv[arg_idx++]); + const ck::long_index_t N = std::stol(argv[arg_idx++]); + const ck::long_index_t K = std::stol(argv[arg_idx++]); + const ck::long_index_t C = std::stol(argv[arg_idx++]); - std::vector filter_spatial_lengths(num_dim_spatial); - std::vector input_spatial_lengths(num_dim_spatial); - std::vector conv_filter_strides(num_dim_spatial); - std::vector conv_filter_dilations(num_dim_spatial); - std::vector input_left_pads(num_dim_spatial); - std::vector input_right_pads(num_dim_spatial); + std::vector filter_spatial_lengths(num_dim_spatial); + std::vector input_spatial_lengths(num_dim_spatial); + std::vector conv_filter_strides(num_dim_spatial); + std::vector conv_filter_dilations(num_dim_spatial); + std::vector input_left_pads(num_dim_spatial); + std::vector input_right_pads(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + filter_spatial_lengths[i] = std::stol(argv[arg_idx++]); } for(int i = 0; i < num_dim_spatial; ++i) { - input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + input_spatial_lengths[i] = std::stol(argv[arg_idx++]); } for(int i = 0; i < num_dim_spatial; ++i) { - conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + conv_filter_strides[i] = std::stol(argv[arg_idx++]); } for(int i = 0; i < num_dim_spatial; ++i) { - conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + conv_filter_dilations[i] = std::stol(argv[arg_idx++]); } for(int i = 0; i < num_dim_spatial; ++i) { - input_left_pads[i] = std::stoi(argv[arg_idx++]); + input_left_pads[i] = std::stol(argv[arg_idx++]); } for(int i = 0; i < num_dim_spatial; ++i) { - input_right_pads[i] = std::stoi(argv[arg_idx++]); + input_right_pads[i] = std::stol(argv[arg_idx++]); } return ck::utils::conv::ConvParam{num_dim_spatial, diff --git a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp index 52152a90fe..b70dd9538d 100644 --- a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -82,6 +82,29 @@ bool profile_conv_bwd_data_impl(int do_verification, Tensor weight(wei_g_k_c_xs_desc); Tensor output(out_g_n_k_wos_desc); + std::vector input_spatial_lengths_i32(NDimSpatial); + std::vector filter_spatial_lengths_i32(NDimSpatial); + std::vector output_spatial_lengths_i32(NDimSpatial); + std::vector conv_filter_strides_i32(NDimSpatial); + std::vector conv_filter_dilations_i32(NDimSpatial); + std::vector input_left_pads_i32(NDimSpatial); + std::vector input_right_pads_i32(NDimSpatial); + + for(ck::index_t d = 0; d < NDimSpatial; d++) + { + input_spatial_lengths_i32[d] = + static_cast(conv_param.input_spatial_lengths_[d]); + filter_spatial_lengths_i32[d] = + static_cast(conv_param.filter_spatial_lengths_[d]); + output_spatial_lengths_i32[d] = + static_cast(conv_param.GetOutputSpatialLengths()[d]); + conv_filter_strides_i32[d] = static_cast(conv_param.conv_filter_strides_[d]); + conv_filter_dilations_i32[d] = + static_cast(conv_param.conv_filter_dilations_[d]); + input_left_pads_i32[d] = static_cast(conv_param.input_left_pads_[d]); + input_right_pads_i32[d] = static_cast(conv_param.input_right_pads_[d]); + } + std::cout << "input: " << input_host_result.mDesc << std::endl; std::cout << "weight: " << weight.mDesc << std::endl; std::cout << "output: " << output.mDesc << std::endl; @@ -161,16 +184,16 @@ bool profile_conv_bwd_data_impl(int do_verification, op_ptr->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), - conv_param.N_, - conv_param.K_, - conv_param.C_, - conv_param.input_spatial_lengths_, - conv_param.filter_spatial_lengths_, - conv_param.output_spatial_lengths_, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, + static_cast(conv_param.N_), + static_cast(conv_param.K_), + static_cast(conv_param.C_), + input_spatial_lengths_i32, + filter_spatial_lengths_i32, + output_spatial_lengths_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, in_element_op, wei_element_op, out_element_op); diff --git a/profiler/include/profiler/profile_conv_fwd_impl.hpp b/profiler/include/profiler/profile_conv_fwd_impl.hpp index bc2eb25797..917e4c07fc 100644 --- a/profiler/include/profiler/profile_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_conv_fwd_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -60,6 +60,29 @@ bool profile_conv_fwd_impl(int do_verification, Tensor host_output(out_g_n_k_wos_desc); Tensor device_output(out_g_n_k_wos_desc); + std::vector input_spatial_lengths_i32(NDimSpatial); + std::vector filter_spatial_lengths_i32(NDimSpatial); + std::vector output_spatial_lengths_i32(NDimSpatial); + std::vector conv_filter_strides_i32(NDimSpatial); + std::vector conv_filter_dilations_i32(NDimSpatial); + std::vector input_left_pads_i32(NDimSpatial); + std::vector input_right_pads_i32(NDimSpatial); + + for(ck::index_t d = 0; d < NDimSpatial; d++) + { + input_spatial_lengths_i32[d] = + static_cast(conv_param.input_spatial_lengths_[d]); + filter_spatial_lengths_i32[d] = + static_cast(conv_param.filter_spatial_lengths_[d]); + output_spatial_lengths_i32[d] = + static_cast(conv_param.GetOutputSpatialLengths()[d]); + conv_filter_strides_i32[d] = static_cast(conv_param.conv_filter_strides_[d]); + conv_filter_dilations_i32[d] = + static_cast(conv_param.conv_filter_dilations_[d]); + input_left_pads_i32[d] = static_cast(conv_param.input_left_pads_[d]); + input_right_pads_i32[d] = static_cast(conv_param.input_right_pads_[d]); + } + std::cout << "input: " << input.mDesc << std::endl; std::cout << "weight: " << weight.mDesc << std::endl; std::cout << "output: " << host_output.mDesc << std::endl; @@ -143,16 +166,16 @@ bool profile_conv_fwd_impl(int do_verification, op_ptr->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), - conv_param.N_, - conv_param.K_, - conv_param.C_, - conv_param.input_spatial_lengths_, - conv_param.filter_spatial_lengths_, - conv_param.GetOutputSpatialLengths(), - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, + static_cast(conv_param.N_), + static_cast(conv_param.K_), + static_cast(conv_param.C_), + input_spatial_lengths_i32, + filter_spatial_lengths_i32, + output_spatial_lengths_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, in_element_op, wei_element_op, out_element_op); diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp index d913873305..f47d6f9889 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -33,7 +33,8 @@ template + typename BComputeType = AComputeType, + typename IndexType = ck::index_t> bool profile_grouped_conv_fwd_impl(int do_verification, int init_method, bool do_log, @@ -57,16 +58,16 @@ bool profile_grouped_conv_fwd_impl(int do_verification, const auto out_g_n_k_wos_desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); - std::array a_g_n_c_wis_lengths{}; - std::array a_g_n_c_wis_strides{}; - std::array b_g_k_c_xs_lengths{}; - std::array b_g_k_c_xs_strides{}; - std::array e_g_n_k_wos_lengths{}; - std::array e_g_n_k_wos_strides{}; - std::array conv_filter_strides{}; - std::array conv_filter_dilations{}; - std::array input_left_pads{}; - std::array input_right_pads{}; + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index 577efafb1e..9397e2dac0 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -29,6 +29,12 @@ enum struct ConvDataType BF8_F8_F8, // 7 }; +enum struct IndexType +{ + INDEX_T, // 0 + LONG_INDEX_T, // 1 +}; + #define OP_NAME "grouped_conv_fwd" #define OP_DESC "Grouped Convolution Forward" @@ -45,12 +51,13 @@ static void print_helper_msg() << " 5: Input bf8, Weight bf8, Output fp8\n" << " 6: Input fp8, Weight bf8, Output fp8\n" << " 7: Input bf8, Weight fp8, Output fp8)\n" - << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << "arg3: indexing data type (0: 32-bit, 1: 64-bit)\n" + << "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" - << "arg4: verification (0: no, 1: yes)\n" - << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" - << "arg6: print tensor value (0: no; 1: yes)\n" - << "arg7: time kernel (0: no, 1: yes)\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; // clang-format on } @@ -60,7 +67,7 @@ static void print_helper_msg() int profile_grouped_conv_fwd(int argc, char* argv[]) { // 8 for control, 1 for num_dim_spatial - if(argc < 9) + if(argc < 10) { print_helper_msg(); return 1; @@ -68,20 +75,21 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) const auto data_type = static_cast(std::stoi(argv[2])); const auto layout = static_cast(std::stoi(argv[3])); - const bool do_verification = std::stoi(argv[4]); - const int init_method = std::stoi(argv[5]); - const bool do_log = std::stoi(argv[6]); - const bool time_kernel = std::stoi(argv[7]); - const int num_dim_spatial = std::stoi(argv[8]); + const auto index_type = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); - // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial - if(argc != 8 + 1 + 4 + 6 * num_dim_spatial) + // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 9 + 1 + 4 + 6 * num_dim_spatial) { print_helper_msg(); return 1; } - const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv); + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); using F32 = float; using F16 = ck::half_t; @@ -138,18 +146,43 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using AComputeType = decltype(a_compute_type); using BComputeType = decltype(b_compute_type); - bool pass = ck::profiler::profile_grouped_conv_fwd_impl( - do_verification, init_method, do_log, time_kernel, params); + if(index_type == IndexType::INDEX_T) + { + bool pass = ck::profiler::profile_grouped_conv_fwd_impl( + do_verification, init_method, do_log, time_kernel, params); - return pass ? 0 : 1; + return pass ? 0 : 1; + } + else if(index_type == IndexType::LONG_INDEX_T) + { + bool pass = ck::profiler::profile_grouped_conv_fwd_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + } + else + { + std::cout << "this indexing data type is not implemented" << std::endl; + return 1; + } }; // GNHWC_GKYXC_GNHWK diff --git a/test/conv_util/conv_util.cpp b/test/conv_util/conv_util.cpp index 6922bbbcc7..124efb6b8a 100644 --- a/test/conv_util/conv_util.cpp +++ b/test/conv_util/conv_util.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -24,12 +24,12 @@ class TestConvUtil : public ::testing::Test 128, 192, 256, - std::vector(ndims, 3), - std::vector(ndims, 71), - std::vector(ndims, s), - std::vector(ndims, d), - std::vector(ndims, p), - std::vector(ndims, p)); + std::vector(ndims, 3), + std::vector(ndims, 71), + std::vector(ndims, s), + std::vector(ndims, d), + std::vector(ndims, p), + std::vector(ndims, p)); } protected: @@ -48,35 +48,35 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D) { // stride 2, dilation 1, pad 1 SetNDParams(1, 2, 1, 1); - std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err( - out_spatial_len, std::vector{36}, "Error: ConvParams 1D.")); + out_spatial_len, std::vector{36}, "Error: ConvParams 1D.")); // stride 1, dilation 1, pad 1 SetNDParams(1, 1, 1, 1); out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err( - out_spatial_len, std::vector{71}, "Error: ConvParams 1D stride {1}.")); + out_spatial_len, std::vector{71}, "Error: ConvParams 1D stride {1}.")); // stride 2, dilation 1, pad 2 SetNDParams(1, 2, 1, 2); out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err(out_spatial_len, - std::vector{37}, + std::vector{37}, "Error: ConvParams 1D padding left/right {2}.")); // stride 2, dilation 2, pad 2 SetNDParams(1, 2, 2, 2); out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err( - out_spatial_len, std::vector{36}, "Error: ConvParams 1D dilation {2}.")); + out_spatial_len, std::vector{36}, "Error: ConvParams 1D dilation {2}.")); // stride 3, dilation 2, pad 1 SetNDParams(1, 3, 2, 1); out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE( ck::utils::check_err(out_spatial_len, - std::vector{23}, + std::vector{23}, "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}.")); } @@ -84,36 +84,38 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D) { // stride 2, dilation 1, pad 1 SetNDParams(2, 2, 1, 1); - std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err(out_spatial_len, - std::vector{36, 36}, + std::vector{36, 36}, "Error: ConvParams 2D default constructor.")); // stride 1, dilation 1, pad 1 SetNDParams(2, 1, 1, 1); out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE(ck::utils::check_err( - out_spatial_len, std::vector{71, 71}, "Error: ConvParams 2D stride {1,1}.")); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{71, 71}, + "Error: ConvParams 2D stride {1,1}.")); // stride 2, dilation 1, pad 2 SetNDParams(2, 2, 1, 2); out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err(out_spatial_len, - std::vector{37, 37}, + std::vector{37, 37}, "Error: ConvParams 2D padding left/right {2,2}.")); // stride 2, dilation 2, pad 2 SetNDParams(2, 2, 2, 2); out_spatial_len = conv_params.GetOutputSpatialLengths(); - EXPECT_TRUE(ck::utils::check_err( - out_spatial_len, std::vector{36, 36}, "Error: ConvParams 2D dilation {2,2}.")); + EXPECT_TRUE(ck::utils::check_err(out_spatial_len, + std::vector{36, 36}, + "Error: ConvParams 2D dilation {2,2}.")); // stride 3, dilation 2, pad 1 SetNDParams(2, 3, 2, 1); out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE( ck::utils::check_err(out_spatial_len, - std::vector{23, 23}, + std::vector{23, 23}, "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}.")); } @@ -121,29 +123,29 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D) { // stride 2, dilation 1, pad 1 SetNDParams(3, 2, 1, 1); - std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err( - out_spatial_len, std::vector{36, 36, 36}, "Error: ConvParams 3D.")); + out_spatial_len, std::vector{36, 36, 36}, "Error: ConvParams 3D.")); // stride 1, dilation 1, pad 1 SetNDParams(3, 1, 1, 1); out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err(out_spatial_len, - std::vector{71, 71, 71}, + std::vector{71, 71, 71}, "Error: ConvParams 3D stride {1, 1, 1}.")); // stride 2, dilation 1, pad 2 SetNDParams(3, 2, 1, 2); out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err(out_spatial_len, - std::vector{37, 37, 37}, + std::vector{37, 37, 37}, "Error: ConvParams 3D padding left/right {2, 2, 2}.")); // stride 2, dilation 2, pad 2 SetNDParams(3, 2, 2, 2); out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err(out_spatial_len, - std::vector{36, 36, 36}, + std::vector{36, 36, 36}, "Error: ConvParams 3D dilation {2, 2, 2}.")); // stride 3, dilation 2, pad 1 @@ -151,6 +153,6 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D) out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err( out_spatial_len, - std::vector{23, 23, 23}, + std::vector{23, 23, 23}, "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}.")); } diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp index 1bfc183135..c86b18e77e 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp @@ -17,6 +17,7 @@ class TestGroupedConvndFwd : public ::testing::Test using InLayout = std::tuple_element_t<1, Tuple>; using WeiLayout = std::tuple_element_t<2, Tuple>; using OutLayout = std::tuple_element_t<3, Tuple>; + using IndexType = std::tuple_element_t<4, Tuple>; std::vector conv_params; @@ -33,7 +34,10 @@ class TestGroupedConvndFwd : public ::testing::Test OutLayout, DataType, DataType, - DataType>( + DataType, + DataType, + DataType, + IndexType>( true, // do_verification 1, // init_method: integer value false, // do_log @@ -46,30 +50,31 @@ class TestGroupedConvndFwd : public ::testing::Test using namespace ck::tensor_layout::convolution; -using KernelTypes1d = ::testing::Types, - std::tuple, - std::tuple, - std::tuple>; +using KernelTypes1d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; -using KernelTypes2d = ::testing::Types, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple>; +using KernelTypes2d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; -using KernelTypes3d = ::testing::Types, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple>; +using KernelTypes3d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; -using KernelTypes2dLargeCases = ::testing::Types>; +using KernelTypes2dLargeCases = + ::testing::Types>; template class TestGroupedConvndFwd1d : public TestGroupedConvndFwd @@ -153,5 +158,8 @@ TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases) // With supported NumGroupsToMerge > 1 this->conv_params.push_back( {2, 32, 64, 1, 1, {2, 2}, {672, 672}, {672, 672}, {1, 1}, {0, 0}, {0, 0}}); + // When image is larger than 2GB + this->conv_params.push_back( + {2, 1, 1, 256, 256, {3, 3}, {4096, 2048}, {1024, 1024}, {3, 3}, {1, 1}, {1, 1}}); this->template Run<2>(); }