From 28f29667621e12e1e4f4fdeed9e7d05118031d95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= <38502616+bartekxk@users.noreply.github.com> Date: Sat, 6 Jun 2026 22:52:59 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#7734 (commit 03ffb9d) [CK] Grouped Convolution Global Load/Store instances ## Motivation Support global load and store in grouped convolutions using instance factory. ## Technical Details - add new instances for each direction - add new tests for large cases ## Test Plan New test for large cases ## Test Result pending ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. AICK-1255 --- .../multi_index_transform_helper.hpp | 4 +- .../tensor_description/tensor_descriptor.hpp | 12 +- ...hread_group_tensor_slice_transfer_v4r2.hpp | 6 +- ...conv_bwd_data_multiple_d_wmma_cshuffle.hpp | 24 +- ...v_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 24 +- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 24 +- ...nv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp | 163 +++-- .../device_grouped_conv_bwd_weight_dl.hpp | 12 +- ...evice_grouped_conv_bwd_weight_explicit.hpp | 12 +- ..._bwd_weight_two_stage_wmma_cshuffle_v3.hpp | 12 +- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 56 +- ..._grouped_conv_bwd_weight_wmma_cshuffle.hpp | 12 +- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 12 +- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 12 +- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 42 +- ...wd_weight_xdl_waveletmodel_cshuffle_v3.hpp | 12 +- ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp | 24 +- ...conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp | 24 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 24 +- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 49 +- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 24 +- ...ltiple_d_wmma_cshuffle_v3_large_tensor.hpp | 24 +- .../gpu/device/tensor_size_check.hpp | 4 +- .../gpu/grid/block_to_ctile_map.hpp | 36 +- 
.../gpu/grid/gridwise_elementwise_2d.hpp | 41 +- .../gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 99 ++-- .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 188 +++--- .../threadwise_tensor_slice_transfer_v3r2.hpp | 25 +- .../transform_conv_bwd_data_to_gemm_v1.hpp | 387 ++++++------ .../transform_conv_bwd_weight_to_gemm.hpp | 13 +- .../transform_conv_bwd_weight_to_gemm_v2.hpp | 17 +- .../transform_conv_fwd_to_gemm.hpp | 65 +- include/ck/utility/number.hpp | 33 ++ .../gpu/naive_conv_bwd_data_gpu.hpp | 545 ++++++++--------- .../gpu/naive_conv_bwd_weight_gpu.hpp | 437 +++++++------- .../gpu/naive_conv_fwd_gpu.hpp | 559 ++++++++---------- .../gpu/naive_conv_utils.hpp | 33 +- ..._grouped_conv_bwd_data_xdl_v3_instance.hpp | 94 +++ ...conv_bwd_weight_two_stage_xdl_instance.hpp | 42 ++ ...rouped_conv_bwd_weight_v3_xdl_instance.hpp | 77 +++ ...ice_grouped_conv_fwd_xdl_comp_instance.hpp | 107 ++++ .../gpu/grouped_convolution_backward_data.hpp | 12 + .../grouped_convolution_backward_data_xdl.inc | 91 +++ .../grouped_convolution_backward_weight.hpp | 26 + ...rouped_convolution_backward_weight_xdl.inc | 120 ++++ .../gpu/grouped_convolution_forward.hpp | 12 + .../grouped_convolution_forward_comp_xdl.inc | 96 +++ .../nhwgc/CMakeLists.txt | 3 + ...kyxc_nhwgk_bf16_large_tensors_instance.cpp | 39 ++ ...gkyxc_nhwgk_f16_large_tensors_instance.cpp | 39 ++ ...gkyxc_nhwgk_f32_large_tensors_instance.cpp | 39 ++ .../nhwgc/CMakeLists.txt | 5 + ...kyxc_nhwgk_bf16_large_tensors_instance.cpp | 40 ++ ...gkyxc_nhwgk_f16_large_tensors_instance.cpp | 40 ++ ...gk_bf16_default_large_tensors_instance.cpp | 38 ++ ...wgk_f16_default_large_tensors_instance.cpp | 38 ++ ...wgk_f32_default_large_tensors_instance.cpp | 38 ++ .../grouped_conv2d_fwd/nhwgc/CMakeLists.txt | 3 + ...nhwgk_bf16_comp_large_tensors_instance.cpp | 39 ++ ..._nhwgk_f16_comp_large_tensors_instance.cpp | 39 ++ ..._nhwgk_f32_comp_large_tensors_instance.cpp | 39 ++ .../ndhwgc/CMakeLists.txt | 4 + ...yxc_ndhwgk_bf16_large_tensors_instance.cpp | 39 
++ ...zyxc_ndhwgk_f16_large_tensors_instance.cpp | 39 ++ ...zyxc_ndhwgk_f32_large_tensors_instance.cpp | 39 ++ .../ndhwgc/CMakeLists.txt | 5 + ...yxc_ndhwgk_bf16_large_tensors_instance.cpp | 41 ++ ...zyxc_ndhwgk_f16_large_tensors_instance.cpp | 41 ++ ...gk_bf16_default_large_tensors_instance.cpp | 39 ++ ...wgk_f16_default_large_tensors_instance.cpp | 39 ++ ...wgk_f32_default_large_tensors_instance.cpp | 39 ++ .../grouped_conv3d_fwd/ndhwgc/CMakeLists.txt | 3 + ...dhwgk_bf16_comp_large_tensors_instance.cpp | 39 ++ ...ndhwgk_f16_comp_large_tensors_instance.cpp | 39 ++ ...ndhwgk_f32_comp_large_tensors_instance.cpp | 39 ++ .../profile_grouped_conv_bwd_weight_impl.hpp | 3 + ...ofile_grouped_conv_fwd_bias_clamp_impl.hpp | 11 +- ...profile_grouped_conv_fwd_bilinear_impl.hpp | 11 +- .../src/profile_grouped_conv_bwd_data.cpp | 14 +- .../src/profile_grouped_conv_bwd_weight.cpp | 14 +- profiler/src/profile_grouped_conv_fwd.cpp | 16 +- python/ck4inductor/grouped_conv_fwd/op.py | 1 + test/gpu_reference/gpu_reference_utils.hpp | 16 +- .../test_grouped_conv_bwd_data_bilinear.cpp | 8 +- ...st_grouped_convnd_bwd_data_large_cases.cpp | 104 ++-- test/grouped_convnd_bwd_weight/CMakeLists.txt | 4 + ...est_grouped_convnd_bwd_weight_bilinear.cpp | 8 +- ..._grouped_convnd_bwd_weight_dataset_xdl.cpp | 20 +- ..._grouped_convnd_bwd_weight_large_cases.cpp | 164 +++++ .../test_grouped_convnd_fwd_large_cases.cpp | 56 +- .../test_grouped_convnd_fwd_scaleadd_ab.cpp | 6 +- 91 files changed, 3469 insertions(+), 1638 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_large_tensors_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f32_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/xdl/device_grouped_conv3d_bwd_data_xdl_v3_ndhwgc_gkzyxc_ndhwgk_bf16_large_tensors_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/xdl/device_grouped_conv3d_bwd_data_xdl_v3_ndhwgc_gkzyxc_ndhwgk_f16_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/xdl/device_grouped_conv3d_bwd_data_xdl_v3_ndhwgc_gkzyxc_ndhwgk_f32_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_large_tensors_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_large_tensors_instance.cpp create 
mode 100644 test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_large_cases.cpp diff --git a/include/ck/tensor_description/multi_index_transform_helper.hpp b/include/ck/tensor_description/multi_index_transform_helper.hpp index c7d64f620b..85aab7c4cf 100644 --- a/include/ck/tensor_description/multi_index_transform_helper.hpp +++ b/include/ck/tensor_description/multi_index_transform_helper.hpp @@ -55,8 +55,8 @@ template __host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths) { // Magic Division is not supported yet for int64_t - using IndexType = decltype(low_lengths.At(Number<0>{})); - if constexpr(std::is_same_v) + using IndexType = remove_cvref_t{}))>; + if constexpr(std::is_same_v || is_long_number_v) { return make_merge_transform_v1_carry_check(low_lengths); } diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index 7b20b299a9..076a81d4f7 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -163,7 +163,7 @@ struct TensorDescriptor __host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; } template - __host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const + __host__ __device__ constexpr auto CalculateOffset(const Idx& idx) const { static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension"); @@ -465,6 +465,8 @@ __host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc& const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack) { + using IndexType = remove_cvref_t{}])>; + static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), "wrong! 
# of dimension inconsistent"); @@ -480,7 +482,7 @@ __host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc& auto is_non_zero_diff = make_zero_multi_index(); // decide do_transform by checkout non-zero index diff components - MultiIndex non_zero_diff_pick_visible; + MultiIndex non_zero_diff_pick_visible; static_for<0, ndim_visible, 1>{}( [&](auto i) { non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); }); @@ -493,7 +495,7 @@ __host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc& const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up); - MultiIndex non_zero_diff_pick_low; + MultiIndex non_zero_diff_pick_low; // if any of upper index diff components is non-zero, then // 1) Need to do this transform @@ -529,6 +531,8 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens TensorCoord& coord, const TensorCoordStep& coord_step) { + using IndexType = remove_cvref_t; + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); @@ -562,7 +566,7 @@ __host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tens auto idx_low = get_container_subset(idx_hidden, dims_low); const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up); - MultiIndex idx_diff_low; + MultiIndex idx_diff_low; // HACK: control UpdateLowerIndex for Merge using hack constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran); diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp index 6ba2a0b917..6803b7ac74 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp @@ -40,7 +40,8 @@ template + index_t NumThreadScratch = 1, + 
typename IndexType = index_t> struct ThreadGroupTensorSliceTransfer_v4r2 { static constexpr index_t nDim = @@ -185,7 +186,8 @@ struct ThreadGroupTensorSliceTransfer_v4r2 DstsScalarStrideInVector, ThreadTransferSrcsResetCoordinateAfterRun, ThreadTransferDstsResetCoordinateAfterRun, - NumThreadScratch>; + NumThreadScratch, + IndexType>; ThreadwiseTransfer threadwise_transfer_; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index ca8c8b0c9b..82f8f2a213 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -793,11 +793,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle const ck::index_t split_k = 1) { bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_c_wis_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_c_wis_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; std::array a_g_n_k_wos_lengths_i32; std::array a_g_n_k_wos_strides_i32; std::array b_g_k_c_xs_lengths_i32; @@ -923,11 +925,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle const ck::index_t split_k = 1) override { bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_c_wis_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || - 
tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_c_wis_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; std::array a_g_n_k_wos_lengths_i32; std::array a_g_n_k_wos_strides_i32; std::array b_g_k_c_xs_lengths_i32; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index 3e77d1cf19..8dbbf1aa06 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -1921,11 +1921,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 const ck::index_t split_k = 1) { bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_c_wis_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_c_wis_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; std::array a_g_n_k_wos_lengths_i32; std::array a_g_n_k_wos_strides_i32; @@ -2052,11 +2054,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 const ck::index_t split_k = 1) override { bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= 
tensor_exceeds_2gb(ds_g_n_c_wis_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_c_wis_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; std::array a_g_n_k_wos_lengths_i32; std::array a_g_n_k_wos_strides_i32; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 199c5a951e..2b89a906bd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -2166,11 +2166,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const ck::index_t split_k = 1) { bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_c_wis_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_c_wis_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; std::array a_g_n_k_wos_lengths_i32; std::array a_g_n_k_wos_strides_i32; @@ -2297,11 +2299,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const ck::index_t split_k = 1) override { bool ds_ovf = 
false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_c_wis_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_c_wis_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; std::array a_g_n_k_wos_lengths_i32; std::array a_g_n_k_wos_strides_i32; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp index c61078f50c..5798822bb3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp @@ -274,9 +274,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 static_assert(NDimSpatial == 2 || NDimSpatial == 3, "wrong! only implemented for 2D and 3D now"); - static_assert(std::is_same_v, "A not NGHWC"); - static_assert(std::is_same_v, "B not GKYXC"); - static_assert(std::is_same_v, "C not NGHWK"); + static_assert(std::is_same_v || + std::is_same_v, + "A not NGHWC"); + static_assert(std::is_same_v || + std::is_same_v, + "B not GKYXC"); + static_assert(std::is_same_v || + std::is_same_v, + "C not NGHWK"); // MaxGroupedGemmGroupsNum is used to specify number of gemm args in compile time. 
With this // implementation we can avoid copy data to workspace before kernel launch since number of @@ -292,6 +298,34 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3; + // Wave32 support: compute effective MXdlPerWave for wave64 and wave32 modes. + // The bwd_data template uses MRepeat/NRepeat as MXdlPerWave/NXdlPerWave and + // MPerXdl/NPerXdl (lowercase 'dl') instead of MPerXDL/NPerXDL. + template + static constexpr auto GetMXdlPerWave() + { + return GetXdlPerWave2(); + } + + static constexpr bool Wave32Force16MNPerXdl = sizeof(ADataType) == 2 && sizeof(BDataType) == 2; + static constexpr index_t Wave32MaxMNPerXdl = + Wave32Force16MNPerXdl ? 16 : math::max(MPerXdl, NPerXdl); + + static constexpr auto MXdlPerWave64 = GetMXdlPerWave(); + static constexpr auto MXdlPerWave32 = GetMXdlPerWave(); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -334,7 +368,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 const auto e_grid_desc_mblock_mperblock_nblock_nperblock = GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - conv_to_gemm_transform.MakeCDescriptor_M_N(), 1, 1); + conv_to_gemm_transform.MakeCDescriptor_M_N(), IndexType{1}, IndexType{1}); return make_tuple(a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, @@ -353,8 +387,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 static constexpr bool ALdsScalarLoadToVgpr = false; static constexpr bool BLdsScalarLoadToVgpr = true; - // GridwiseGemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3< + // Parameterized GridwiseGemm template to support both wave64 (MPerXdl/NPerXdl) and + // wave32 (Wave32MaxMNPerXdl/Wave32MaxMNPerXdl) XDL instruction sizes. 
+ template + using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::RowMajor, tensor_layout::gemm::RowMajor, @@ -373,10 +409,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 KPerBlock, AK1, BK1, - MPerXdl, - NPerXdl, - MRepeat, - NRepeat, + MPerXdl_, + NPerXdl_, + MRepeat_, + NRepeat*(NPerXdl / NPerXdl_), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -394,7 +430,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, - CShuffleNRepeatPerShuffle, + CShuffleNRepeatPerShuffle*(NPerXdl / NPerXdl_), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector, BlkGemmPipeSched, @@ -402,10 +438,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 AComputeType, BComputeType, DirectLoad, - ALdsScalarLoadToVgpr, - BLdsScalarLoadToVgpr, + DirectLoad && ALdsScalarLoadToVgpr, + DirectLoad && BLdsScalarLoadToVgpr, LargeTensors>; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = + GridwiseGemmBase; + // Default GridwiseGemm alias for use in non-wave-size-dependent code paths + using GridwiseGemm = GridwiseGemm64; + template static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) { @@ -632,8 +674,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 const auto b_grid_desc_n_k = transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); - const auto GemmM = a_grid_desc_m_k.GetLength(I0); - const auto GemmN = b_grid_desc_n_k.GetLength(I0); + const IndexType GemmM = a_grid_desc_m_k.GetLength(I0); + const IndexType GemmN = b_grid_desc_n_k.GetLength(I0); const auto e_grid_desc_mblock_mperblock_nblock_nperblock = GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -765,7 +807,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 { using Argument = DeviceOp::Argument; - template 
+ template float RunMultiDGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; @@ -783,7 +825,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 const index_t GemmM = arg.a_grid_desc_m_k_container_[gemm_set_id].GetLength(I0); const index_t GemmN = arg.b_grid_desc_n_k_container_[gemm_set_id].GetLength(I0); const index_t GemmK = arg.a_grid_desc_m_k_container_[gemm_set_id].GetLength(I1); - typename GridwiseGemm::Argument gemm_arg{ + typename GridwiseGemm_::Argument gemm_arg{ p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; const index_t gdx = arg.gemms_grid_size_[gemm_set_id]; @@ -816,7 +858,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 constexpr bool has_main_loop = has_main_k_block_loop_.value; constexpr bool no_main_loop = no_main_k_block_loop.value; const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, + GridwiseGemm_, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::EGridDesc_MPerBlock_NBlock_NPerBlock, @@ -868,17 +910,46 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 { arg.Print(); } - if(arg.k_batch_ > 1) + + if(get_warp_size() == 64) { - if constexpr(IsSplitKSupported) + if constexpr(MXdlPerWave64 > 0) { - ave_time += - RunMultiDGemm(arg, stream_config); + if(arg.k_batch_ > 1) + { + if constexpr(IsSplitKSupported) + { + ave_time += + RunMultiDGemm( + arg, stream_config); + } + } + else + { + ave_time += RunMultiDGemm( + arg, stream_config); + } } } else { - ave_time += RunMultiDGemm(arg, stream_config); + if constexpr(MXdlPerWave32 > 0) + { + if(arg.k_batch_ > 1) + { + if constexpr(IsSplitKSupported) + { + ave_time += + RunMultiDGemm( + arg, stream_config); + } + } + else + { + ave_time += RunMultiDGemm( + arg, stream_config); + } + } } return ave_time; @@ -901,6 +972,28 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 } } + if(get_warp_size() != 64) + { + // TODO: Enable for warp size 32 
+ return false; + } + // Reject if the current warp size has no valid XDL configuration + // Warp size 32 is temporary not supported but leave it for the future + if(get_warp_size() == 64) + { + if constexpr(MXdlPerWave64 == 0) + { + return false; + } + } + else + { + if constexpr(MXdlPerWave32 == 0) + { + return false; + } + } + // check device if constexpr(DirectLoad) { @@ -1257,11 +1350,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 else { bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_c_wis_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_c_wis_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; std::array a_g_n_k_wos_lengths_i32; std::array a_g_n_k_wos_strides_i32; @@ -1471,11 +1566,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 else { bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_c_wis_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_c_wis_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_k_wos_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_c_wis_lengths) || ds_ovf; std::array a_g_n_k_wos_lengths_i32; std::array a_g_n_k_wos_strides_i32; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 136d59a160..6ea9e134aa 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -1199,9 +1199,9 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths); std::array a_g_n_c_wis_lengths_i32; std::array a_g_n_c_wis_strides_i32; @@ -1304,9 +1304,9 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths); std::array a_g_n_c_wis_lengths_i32; std::array a_g_n_c_wis_strides_i32; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp index 2e165dd3d7..c2820495eb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp @@ -480,9 +480,9 @@ struct DeviceGroupedConvBwdWeight_Explicit OutElementwiseOperation out_element_op, const ck::index_t split_k) { - const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || - tensor_exceeds_2gb(e_g_k_c_xs_lengths) || - tensor_exceeds_2gb(a_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || + tensor_exceeds_2gb(e_g_k_c_xs_lengths) || + tensor_exceeds_2gb(a_g_n_k_wos_lengths); std::array b_g_n_c_wis_lengths_i32; std::array b_g_n_c_wis_strides_i32; @@ -585,9 +585,9 @@ struct DeviceGroupedConvBwdWeight_Explicit OutElementwiseOperation out_element_op, ck::index_t split_k) override { - const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || - 
tensor_exceeds_2gb(e_g_k_c_xs_lengths) || - tensor_exceeds_2gb(a_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || + tensor_exceeds_2gb(e_g_k_c_xs_lengths) || + tensor_exceeds_2gb(a_g_n_k_wos_lengths); std::array b_g_n_c_wis_lengths_i32; std::array b_g_n_c_wis_strides_i32; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 70befa866f..1b2f856225 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -1512,9 +1512,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 OutElementwiseOperation out_element_op, const ck::index_t split_k) { - const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || - tensor_exceeds_2gb(e_g_k_c_xs_lengths) || - tensor_exceeds_2gb(a_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || + tensor_exceeds_2gb(e_g_k_c_xs_lengths) || + tensor_exceeds_2gb(a_g_n_k_wos_lengths); std::array b_g_n_c_wis_lengths_i32; std::array b_g_n_c_wis_strides_i32; std::array e_g_k_c_xs_lengths_i32; @@ -1618,9 +1618,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 OutElementwiseOperation out_element_op, ck::index_t split_k) override { - const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || - tensor_exceeds_2gb(e_g_k_c_xs_lengths) || - tensor_exceeds_2gb(a_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || + tensor_exceeds_2gb(e_g_k_c_xs_lengths) || + tensor_exceeds_2gb(a_g_n_k_wos_lengths); std::array b_g_n_c_wis_lengths_i32; std::array b_g_n_c_wis_strides_i32; diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 8a5bb93a8a..508e166d71 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -529,7 +529,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle Sequence, Sequence, I1, - I1>; + I1, + IndexType>; // NPerBlock is used for the first dim which is store dimension // (with CBlockTransferScalarPerVector_NWaveNPerXdl scalar per vector). // CBlockTransferScalarPerVector_NWaveNPerXdl is aligned to NPerBlock so @@ -723,7 +724,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle if(split_k < 0) { - ck::index_t gemmM, gemmN, gemmK; + IndexType gemmM, gemmN, gemmK; std::tie(gemmM, gemmN, gemmK) = get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); @@ -734,8 +735,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle // Ensure that k_batch_ does not exceed the maximum value // for the GEMM pipeline. 
- const auto k_batch_max = math::integer_divide_ceil(gemmK, KPerBlock); - k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1); + const auto k_batch_max = + static_cast(math::integer_divide_ceil(gemmK, KPerBlock)); + k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1); if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { @@ -858,13 +860,16 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize(); } - const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); - const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); + const IndexType GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); + const IndexType GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); // A/B/C Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; - compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideA_ = + static_cast(a_g_n_k_wos_strides_transposed[0]); + compute_ptr_offset_of_batch_.BatchStrideB_ = + static_cast(b_g_n_c_wis_strides_transposed[0]); + compute_ptr_offset_of_batch_.BatchStrideC_ = + static_cast(e_g_k_c_xs_strides_transposed[0]); c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ce_grid_desc_m_n_, @@ -1025,9 +1030,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle template float RunGemmV3(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); - const index_t GemmK = + const IndexType GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const IndexType GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const IndexType GemmK = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); AccDataType* p_c_grid = 
type_convert(arg.p_workspace_); @@ -1050,7 +1055,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle typename GridwiseGemm::Argument gemm_arg{ p_a_grid, p_b_grid, p_c_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; - index_t gdx, gdy, gdz; + IndexType gdx, gdy, gdz; std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge); @@ -1697,8 +1702,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle auto launch_elementwise_kernel = [&]() { const AccDataType* p_c_grid = type_convert(arg.p_workspace_); - std::array in_out_batch_strides = { - static_cast(arg.compute_ptr_offset_of_batch_.BatchStrideC_)}; + std::array in_out_batch_strides = { + static_cast(arg.compute_ptr_offset_of_batch_.BatchStrideC_)}; if constexpr(is_NGCHW_GKCYX_NGKHW() || is_NGCDHW_GKCZYX_NGKDHW()) @@ -1741,7 +1746,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle Block2TileMapElementwise, CDEElementwiseOperation, I1, - I1>; + I1, + IndexType>; return launch_and_time_kernel(stream_config, kernel, @@ -1839,9 +1845,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle return false; } - const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); - const index_t GemmK = + const IndexType GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const IndexType GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const IndexType GemmK = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); if constexpr(is_same_v || is_same_v) @@ -2218,9 +2224,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } else { - const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || - tensor_exceeds_2gb(e_g_k_c_xs_lengths) || - tensor_exceeds_2gb(a_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || + tensor_exceeds_2gb(e_g_k_c_xs_lengths) || + tensor_exceeds_2gb(a_g_n_k_wos_lengths); 
std::array b_g_n_c_wis_lengths_i32; std::array b_g_n_c_wis_strides_i32; @@ -2399,9 +2405,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } else { - const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || - tensor_exceeds_2gb(e_g_k_c_xs_lengths) || - tensor_exceeds_2gb(a_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || + tensor_exceeds_2gb(e_g_k_c_xs_lengths) || + tensor_exceeds_2gb(a_g_n_k_wos_lengths); std::array b_g_n_c_wis_lengths_i32; std::array b_g_n_c_wis_strides_i32; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index 60679e4924..dc211f6cc8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -828,9 +828,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle OutElementwiseOperation out_element_op, const index_t split_k) { - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths); std::array a_g_n_c_wis_lengths_i32; std::array a_g_n_c_wis_strides_i32; @@ -933,9 +933,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle OutElementwiseOperation out_element_op, ck::index_t split_k) override { - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths); std::array a_g_n_c_wis_lengths_i32; std::array 
a_g_n_c_wis_strides_i32; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 448327a809..b82b59d22b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -1356,9 +1356,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 OutElementwiseOperation out_element_op, const ck::index_t split_k) { - const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || - tensor_exceeds_2gb(e_g_k_c_xs_lengths) || - tensor_exceeds_2gb(a_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || + tensor_exceeds_2gb(e_g_k_c_xs_lengths) || + tensor_exceeds_2gb(a_g_n_k_wos_lengths); std::array b_g_n_c_wis_lengths_i32; std::array b_g_n_c_wis_strides_i32; std::array e_g_k_c_xs_lengths_i32; @@ -1462,9 +1462,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 OutElementwiseOperation out_element_op, ck::index_t split_k) override { - const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || - tensor_exceeds_2gb(e_g_k_c_xs_lengths) || - tensor_exceeds_2gb(a_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || + tensor_exceeds_2gb(e_g_k_c_xs_lengths) || + tensor_exceeds_2gb(a_g_n_k_wos_lengths); std::array b_g_n_c_wis_lengths_i32; std::array b_g_n_c_wis_strides_i32; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 08a3cebe55..b6cacd2197 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ 
-1366,9 +1366,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle OutElementwiseOperation out_element_op, const ck::index_t split_k) { - const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || - tensor_exceeds_2gb(e_g_k_c_xs_lengths) || - tensor_exceeds_2gb(a_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || + tensor_exceeds_2gb(e_g_k_c_xs_lengths) || + tensor_exceeds_2gb(a_g_n_k_wos_lengths); std::array b_g_n_c_wis_lengths_i32; std::array b_g_n_c_wis_strides_i32; std::array e_g_k_c_xs_lengths_i32; @@ -1472,9 +1472,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle OutElementwiseOperation out_element_op, ck::index_t split_k) override { - const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || - tensor_exceeds_2gb(e_g_k_c_xs_lengths) || - tensor_exceeds_2gb(a_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || + tensor_exceeds_2gb(e_g_k_c_xs_lengths) || + tensor_exceeds_2gb(a_g_n_k_wos_lengths); std::array b_g_n_c_wis_lengths_i32; std::array b_g_n_c_wis_strides_i32; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 54c5bfb3cf..c784827828 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -184,8 +184,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ 
char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; DispatchSplitKHack_2Lds float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); - const index_t GemmK = + const IndexType GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const IndexType GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const IndexType GemmK = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); const ADataType* p_a_grid = arg.p_a_grid_; @@ -1770,9 +1770,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } else { - const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || - tensor_exceeds_2gb(e_g_k_c_xs_lengths) || - tensor_exceeds_2gb(a_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || + tensor_exceeds_2gb(e_g_k_c_xs_lengths) || + tensor_exceeds_2gb(a_g_n_k_wos_lengths); std::array b_g_n_c_wis_lengths_i32; std::array b_g_n_c_wis_strides_i32; @@ -1951,9 +1951,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } else { - const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || - tensor_exceeds_2gb(e_g_k_c_xs_lengths) || - tensor_exceeds_2gb(a_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || + tensor_exceeds_2gb(e_g_k_c_xs_lengths) || + tensor_exceeds_2gb(a_g_n_k_wos_lengths); std::array b_g_n_c_wis_lengths_i32; std::array b_g_n_c_wis_strides_i32; @@ -2009,6 +2009,17 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 { auto str = std::stringstream(); + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + 
{BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + // clang-format off str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; @@ -2042,6 +2053,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 << CBlockTransferScalarPerVector_NWaveNPerXdl; if constexpr(NumGroupsToMerge > 1) str << ", " << NumGroupsToMerge; + if constexpr(LargeTensors) { + // Should be added for all instances but due to backward compatibility, + // there is a lack of this information for other instances than Large + // Tensors. + str << ", BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer]; + } str << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_waveletmodel_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_waveletmodel_cshuffle_v3.hpp index aeeb4765c2..d1dd317463 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_waveletmodel_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_waveletmodel_cshuffle_v3.hpp @@ -947,9 +947,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_WaveletModel_CShuffleV3 OutElementwiseOperation out_element_op, const ck::index_t split_k) { - const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || - tensor_exceeds_2gb(e_g_k_c_xs_lengths) || - tensor_exceeds_2gb(a_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || + tensor_exceeds_2gb(e_g_k_c_xs_lengths) || + tensor_exceeds_2gb(a_g_n_k_wos_lengths); std::array b_g_n_c_wis_lengths_i32; std::array b_g_n_c_wis_strides_i32; std::array e_g_k_c_xs_lengths_i32; @@ -1049,9 +1049,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_WaveletModel_CShuffleV3 OutElementwiseOperation out_element_op, ck::index_t split_k) override { - const bool
stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || - tensor_exceeds_2gb(e_g_k_c_xs_lengths) || - tensor_exceeds_2gb(a_g_n_k_wos_lengths); + const bool stride_ovf = tensor_exceeds_2gb(b_g_n_c_wis_lengths) || + tensor_exceeds_2gb(e_g_k_c_xs_lengths) || + tensor_exceeds_2gb(a_g_n_k_wos_lengths); std::array b_g_n_c_wis_lengths_i32; std::array b_g_n_c_wis_strides_i32; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index fb6bf7c378..e7716e773b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -900,11 +900,13 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK array_convert(input_right_pads_i32, input_right_pads); bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; return Argument{p_a, p_b, p_ds, @@ -1024,11 +1026,13 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK array_convert(input_right_pads_i32, input_right_pads); bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - 
tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; return std::make_unique(p_a, p_b, p_ds, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp index 9b0358e5d2..5d28ced866 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp @@ -2160,11 +2160,13 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 array_convert(input_right_pads_i32, input_right_pads); bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; return Argument{p_as, p_bs, p_ds, @@ -2284,11 +2286,13 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 array_convert(input_right_pads_i32, input_right_pads); bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || - 
tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; return std::make_unique(p_as, p_bs, p_ds, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 21eebd4481..8c61ac21e6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -2015,11 +2015,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle array_convert(input_right_pads_i32, input_right_pads); bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; return Argument{p_as, p_bs, p_ds, @@ -2140,11 +2142,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle array_convert(input_right_pads_i32, input_right_pads); bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || - 
tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; return std::make_unique(p_as, p_bs, p_ds, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 5daec60ad4..8a7fd8bb30 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -1008,15 +1008,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 constexpr index_t minimum_occupancy = BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; - const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); - const index_t GemmK = + const IndexType GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const IndexType GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); + const IndexType GemmK = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); const auto num_workgroups_per_Conv_N = arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; - index_t gdx, gdy, gdz; + IndexType gdx, gdy, gdz; // TODO: Do we want to support kbatch ?? 
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); @@ -1814,16 +1814,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } // Gridwise gemm v3 doesn't verify descriptors size - if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB()) + if(!LargeTensors) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB()) { - std::cout - << "[conv_to_gemm_transformer_] One of the descriptors is bigger than 2GB!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "[conv_to_gemm_transformer_] One of the descriptors is bigger than 2GB!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; } - return false; } // check Gridwise GEMM @@ -2031,11 +2034,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 else { bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; std::array a_g_n_c_wis_lengths_i32; std::array a_g_n_c_wis_strides_i32; @@ -2237,11 +2242,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 else { bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - 
tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; std::array a_g_n_c_wis_lengths_i32; std::array a_g_n_c_wis_strides_i32; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 737c019edd..0b08369b9d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -845,11 +845,13 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle array_convert(input_right_pads_i32, input_right_pads); bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; return Argument{p_a, p_b, p_ds, @@ -973,11 +975,13 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle array_convert(input_right_pads_i32, input_right_pads); bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || - 
tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; return std::make_unique(p_a, p_b, p_ds, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp index af14aa96ea..76d01c5a99 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp @@ -1311,11 +1311,13 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor const CDEElementwiseOperation& cde_element_op) { bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; return Argument{p_a, p_b, p_ds, @@ -1438,11 +1440,13 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor const CDEElementwiseOperation& cde_element_op) override { bool ds_ovf = false; - for(index_t d = 0; d < NumDTensor; d++) - ds_ovf |= 
tensor_exceeds_2gb(ds_g_n_k_wos_lengths[d]); - const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || - tensor_exceeds_2gb(b_g_k_c_xs_lengths) || - tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_ovf |= tensor_exceeds_2gb(ds_g_n_k_wos_lengths[i]); + }); + const bool stride_ovf = tensor_exceeds_2gb(a_g_n_c_wis_lengths) || + tensor_exceeds_2gb(b_g_k_c_xs_lengths) || + tensor_exceeds_2gb(e_g_n_k_wos_lengths) || ds_ovf; return std::make_unique(p_a, p_b, p_ds, diff --git a/include/ck/tensor_operation/gpu/device/tensor_size_check.hpp b/include/ck/tensor_operation/gpu/device/tensor_size_check.hpp index 723b0bfb14..5056f0c898 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_size_check.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_size_check.hpp @@ -9,11 +9,11 @@ namespace ck { namespace tensor_operation { namespace device { -template +template bool tensor_exceeds_2gb(const Lengths& lengths) { constexpr long_index_t TwoGB = (long_index_t{1} << 31); - long_index_t total = 1; + long_index_t total = sizeof(DataType); for(const auto& l : lengths) total *= l; return total > TwoGB; diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 73a1cf9df1..17294beb8e 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -271,21 +271,21 @@ struct BlockToCTileMap_M00_N0_M01Adapt // Grouped Rows of column-vectors WGP mapping // Optimized for gfx94x-like multipe-die chip -template +template struct BlockToCTileMap_Grouped_M00_N0_M01Adapt { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default; - __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M, - index_t N, - index_t M01 = 8) + 
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(IndexType M, + IndexType N, + IndexType M01 = 8) : M_(M), N_(N), M01_(M01) { } - __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + __host__ __device__ static constexpr IndexType CalculateGridSize(IndexType M, IndexType N) { const auto M0 = math::integer_divide_ceil(M, MPerBlock); const auto N0 = math::integer_divide_ceil(N, NPerBlock); @@ -302,18 +302,18 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt template __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const { - auto block_1d_id = idx_top[I0]; + IndexType block_1d_id = idx_top[I0]; - const auto M0 = math::integer_divide_ceil(M_, MPerBlock); - const auto N0 = math::integer_divide_ceil(N_, NPerBlock); + const IndexType M0 = math::integer_divide_ceil(M_, MPerBlock); + const IndexType N0 = math::integer_divide_ceil(N_, NPerBlock); if(M0 == 1) { - return make_tuple(0, block_1d_id); + return make_tuple(IndexType{0}, block_1d_id); } else if(N0 == 1) { - return make_tuple(block_1d_id, 0); + return make_tuple(block_1d_id, IndexType{0}); } // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index else @@ -327,14 +327,14 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt ? group_id_x * group_size + group_id_y : group_id_x * group_size + big_group_num - group_id_x + group_id_y; - index_t idx_N0 = remap_block_1d_id % N0; - index_t idx_M0 = remap_block_1d_id / N0; + IndexType idx_N0 = remap_block_1d_id % N0; + IndexType idx_M0 = remap_block_1d_id / N0; const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? 
M01_ : M0 % M01_; - index_t idx_M00 = idx_M0 / M01_; - index_t idx_M01 = idx_M0 % M01_; - index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + IndexType idx_M00 = idx_M0 / M01_; + IndexType idx_M01 = idx_M0 % M01_; + IndexType idx_N0_M01_local = idx_N0 + idx_M01 * N0; /** * idxN0 @@ -393,9 +393,9 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt } private: - index_t M_; - index_t N_; - index_t M01_; + IndexType M_; + IndexType N_; + IndexType M01_; }; // columns of row-vectors diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp index ca0372b521..c0abda25c3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp @@ -213,7 +213,8 @@ template + index_t NumOutputs, + typename IndexType = index_t> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -224,18 +225,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op, - const index_t batch_count, - const std::array input_batch_strides, - const std::array output_batch_strides) + const IndexType batch_count, + const std::array input_batch_strides, + const std::array output_batch_strides) { static_assert(InGridDescTuple::Size() == NumInputs && InDataTypePointerTuple::Size() == NumInputs); static_assert(OutGridDescTuple::Size() == NumOutputs && OutDataTypePointerTuple::Size() == NumOutputs); - const index_t num_blocks_per_batch = + const IndexType num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + const IndexType g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); 
InDataTypePointerTuple p_in_global_with_offset_tuple; OutDataTypePointerTuple p_out_global_with_offset_tuple; @@ -273,7 +275,8 @@ template + index_t DstVectorDim, + typename IndexType = index_t> struct GridwiseElementwise { static constexpr index_t NumInput = InDataTypePointerTuple::Size(); @@ -322,15 +325,19 @@ struct GridwiseElementwise const auto in_global_buf_tuple = generate_tuple( [&](auto I) { - return make_dynamic_buffer( - p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize()); + return make_dynamic_buffer(p_in_global_tuple[I], + in_grid_desc_tuple[I].GetElementSpaceSize()); }, Number{}); auto out_global_buf_tuple = generate_tuple( [&](auto I) { - return make_dynamic_buffer( - p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize()); + return make_dynamic_buffer(p_out_global_tuple[I], + out_grid_desc_tuple[I].GetElementSpaceSize()); }, Number{}); @@ -386,11 +393,13 @@ struct GridwiseElementwise uniform_sequence_gen_t, uniform_sequence_gen_t, uniform_sequence_gen_t, - uniform_sequence_gen_t>{in_grid_desc_tuple, - input_thread_grid_offset, - out_grid_desc_tuple, - output_thread_grid_offset, - elementwise_op}; + uniform_sequence_gen_t, + 1, + IndexType>{in_grid_desc_tuple, + input_thread_grid_offset, + out_grid_desc_tuple, + output_thread_grid_offset, + elementwise_op}; global_to_global_transfer.Run( in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index c9434ebfa2..a300b817e4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -210,59 +210,59 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 is_single_rate_mfma, is_scale_mfma>::selected_mfma.k_per_blk); - __host__ static auto CalculateGridSize(index_t M, index_t N, 
index_t KBatch, index_t Batch) + __host__ static auto CalculateGridSize(IndexType M, IndexType N, index_t KBatch, index_t Batch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch); } - __host__ static auto CalculateMPadded(index_t M) + __host__ static IndexType CalculateMPadded(IndexType M) { return math::integer_least_multiple(M, MPerBlock); } - __host__ static auto CalculateNPadded(index_t N) + __host__ static IndexType CalculateNPadded(IndexType N) { return math::integer_least_multiple(N, NPerBlock); } - __host__ static auto CalculateKPadded(index_t K) + __host__ static IndexType CalculateKPadded(IndexType K) { return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; } - __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + __host__ static IndexType CalculateAK0Padded(IndexType K, IndexType K_Batch = 1) { auto K_t = K_Batch * KPerBlock; return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); } - __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + __host__ static IndexType CalculateBK0Padded(IndexType K, IndexType K_Batch = 1) { auto K_t = K_Batch * KPerBlock; return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); } - __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + __host__ static IndexType CalculateKPadded(IndexType K, IndexType K_Batch = 1) { auto K_t = K_Batch * KPerBlock; return (K + K_t - 1) / K_t * KPerBlock; } - __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + __host__ static IndexType CalculateKRead(IndexType K, IndexType K_Batch = 1) { constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); auto K_t = K_Batch * KReadVec; return (K + K_t - 1) / K_t * KReadVec; } - __host__ static auto CalculateMBlock(index_t M) + __host__ static IndexType CalculateMBlock(IndexType M) { - return math::integer_divide_ceil(M, MPerBlock); + return math::integer_divide_ceil(M, static_cast(MPerBlock)); } - __host__ static auto CalculateNBlock(index_t N) + 
__host__ static IndexType CalculateNBlock(IndexType N) { - return math::integer_divide_ceil(N, NPerBlock); + return math::integer_divide_ceil(N, static_cast(NPerBlock)); } template @@ -379,13 +379,13 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 struct Problem { - __host__ Problem(index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - index_t StrideC_, - index_t KBatch_) + __host__ Problem(IndexType M_, + IndexType N_, + IndexType K_, + IndexType StrideA_, + IndexType StrideB_, + IndexType StrideC_, + IndexType KBatch_) : M{M_}, N{N_}, K{K_}, @@ -414,21 +414,21 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 << "NBlock: " << NBlock << "}" << std::endl; } - index_t M; - index_t N; - index_t K; - index_t StrideA; - index_t StrideB; - index_t StrideC; - index_t KBatch; - index_t MPadded; - index_t NPadded; - index_t KRead; - index_t KPadded; - index_t AK0; - index_t BK0; - index_t MBlock; - index_t NBlock; + IndexType M; + IndexType N; + IndexType K; + IndexType StrideA; + IndexType StrideB; + IndexType StrideC; + IndexType KBatch; + IndexType MPadded; + IndexType NPadded; + IndexType KRead; + IndexType KPadded; + IndexType AK0; + IndexType BK0; + IndexType MBlock; + IndexType NBlock; }; // Argument @@ -437,13 +437,13 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 __host__ Argument(const ADataType* p_a_grid_, const BDataType* p_b_grid_, CDataType* p_c_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - index_t StrideC_, - index_t k_batch_) + IndexType M_, + IndexType N_, + IndexType K_, + IndexType StrideA_, + IndexType StrideB_, + IndexType StrideC_, + IndexType k_batch_) : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, p_a_grid{p_a_grid_}, p_b_grid{p_b_grid_}, @@ -599,14 +599,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); } + template + using NumberType = + std::conditional_t, Number, LongNumber>; + template __host__ __device__ static 
constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + const CGridDesc& c_grid_desc_m_n, IndexType MBlock, IndexType NBlock) { const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( c_grid_desc_m_n, - make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), - make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(make_unmerge_transform(make_tuple(MBlock, NumberType{})), + make_unmerge_transform(make_tuple(NBlock, NumberType{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); @@ -615,7 +619,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 // return block_id to C matrix tile idx (m0, n0) mapping // if arch = gfx942 - using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + using Block2CTileMap = + BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock, IndexType>; template ( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const AElementwiseOperation a_element_op{}; const BElementwiseOperation b_element_op{}; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index 99e5828661..dc53f323b5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -312,57 +312,57 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 is_single_rate_mfma, is_scale_mfma>::selected_mfma.k_per_blk); - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + __host__ static auto CalculateGridSize(IndexType M, IndexType N, IndexType KBatch) { return 
std::make_tuple(Block2CTileMapDefault::CalculateGridSize(M, N), 1, KBatch); } - __host__ __device__ static auto CalculateMPadded(index_t M) + __host__ __device__ static IndexType CalculateMPadded(IndexType M) { return math::integer_least_multiple(M, MPerBlock); } - __host__ __device__ static auto CalculateNPadded(index_t N) + __host__ __device__ static IndexType CalculateNPadded(IndexType N) { return math::integer_least_multiple(N, NPerBlock); } - __host__ __device__ static auto CalculateKPadded(index_t K) + __host__ __device__ static IndexType CalculateKPadded(IndexType K) { return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; } - __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + __host__ __device__ static IndexType CalculateAK0Padded(IndexType K, IndexType K_Batch = 1) { auto K_t = K_Batch * KPerBlock; return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); } - __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + __host__ __device__ static IndexType CalculateBK0Padded(IndexType K, IndexType K_Batch = 1) { auto K_t = K_Batch * KPerBlock; return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); } - __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + __host__ __device__ static IndexType CalculateKPadded(IndexType K, IndexType K_Batch = 1) { auto K_t = K_Batch * KPerBlock; return (K + K_t - 1) / K_t * KPerBlock; } - __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + __host__ __device__ static IndexType CalculateKRead(IndexType K, IndexType K_Batch = 1) { constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); auto K_t = K_Batch * KReadVec; return (K + K_t - 1) / K_t * KReadVec; } - __host__ __device__ static auto CalculateMBlock(index_t M) + __host__ __device__ static IndexType CalculateMBlock(IndexType M) { return math::integer_divide_ceil(M, MPerBlock); } - __host__ __device__ static auto CalculateNBlock(index_t N) + __host__ __device__ 
static IndexType CalculateNBlock(IndexType N) { return math::integer_divide_ceil(N, NPerBlock); } @@ -683,14 +683,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 struct Problem { __host__ __device__ Problem() = default; - __host__ __device__ Problem(index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - std::array StrideDs_, - index_t StrideC_, - index_t KBatch_) + __host__ __device__ Problem(IndexType M_, + IndexType N_, + IndexType K_, + IndexType StrideA_, + IndexType StrideB_, + std::array StrideDs_, + IndexType StrideC_, + IndexType KBatch_) : M{M_}, N{N_}, K{K_}, @@ -720,22 +720,22 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 << "NBlock: " << NBlock << "}" << std::endl; } - index_t M; - index_t N; - index_t K; - index_t StrideA; - index_t StrideB; - std::array StrideDs; - index_t StrideC; - index_t KBatch; - index_t MPadded; - index_t NPadded; - index_t KRead; - index_t KPadded; - index_t AK0; - index_t BK0; - index_t MBlock; - index_t NBlock; + IndexType M; + IndexType N; + IndexType K; + IndexType StrideA; + IndexType StrideB; + std::array StrideDs; + IndexType StrideC; + IndexType KBatch; + IndexType MPadded; + IndexType NPadded; + IndexType KRead; + IndexType KPadded; + IndexType AK0; + IndexType BK0; + IndexType MBlock; + IndexType NBlock; }; // Argument @@ -746,14 +746,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 const BDataType* p_b_grid_, std::array p_ds_grid_, CDataType* p_c_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - std::array StrideDs_, - index_t StrideC_, - index_t k_batch_, + IndexType M_, + IndexType N_, + IndexType K_, + IndexType StrideA_, + IndexType StrideB_, + std::array StrideDs_, + IndexType StrideC_, + IndexType k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_) @@ -1319,7 +1319,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 // return block_id to C matrix tile idx (m0, n0) mapping // if arch 
= gfx942 - using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + using Block2CTileMapDefault = + BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock, IndexType>; template (blockwise_gemm_pipeline, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_thread_buf, - block_m_id, - block_n_id, - p_shared, - p_ds_grid, - p_c_grid, - c_element_op); + if constexpr(LargeTensors) + { + static_assert(NumDTensor == 0, "Not implemented"); + Base::template RunEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_c_grid, + c_element_op); + } + else + { + // shuffle C and write out + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + Base::template RunMultiDEpilogue(blockwise_gemm_pipeline, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_ds_grid, + p_c_grid, + c_element_op); + } } template (blockwise_gemm_pipeline, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_thread_buf, - block_m_id, - block_n_id, - p_shared_0, - p_ds_grid, - p_c_grid, - c_element_op); + if constexpr(LargeTensors) + { + static_assert(NumDTensor == 0, "Not implemented"); + Base::template RunEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared_0, + p_c_grid, + c_element_op); + } + else + { + // shuffle C and write out + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + Base::template RunMultiDEpilogue(blockwise_gemm_pipeline, + 
ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared_0, + p_ds_grid, + p_c_grid, + c_element_op); + } } }; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp index 24fbd66be6..fdff6e689e 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp @@ -40,7 +40,8 @@ template + index_t NumThreadScratch = 1, + typename IndexType = index_t> struct ThreadwiseTensorSliceTransfer_v3r2 { static constexpr index_t nDim = SliceLengths::Size(); @@ -57,8 +58,9 @@ struct ThreadwiseTensorSliceTransfer_v3r2 enable_if_t = false> static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) { - return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, - Number{}); + return generate_tuple( + [&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); } using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); @@ -81,8 +83,8 @@ struct ThreadwiseTensorSliceTransfer_v3r2 const Indices& src_slice_origin_idxs) { static_for<0, nSrc, 1>{}([&](auto src_i) { - src_coords_(src_i) = - make_tensor_coordinate(src_descs.At(src_i), src_slice_origin_idxs[src_i]); + src_coords_(src_i) = make_tensor_coordinate(src_descs.At(src_i), + src_slice_origin_idxs[src_i]); }); } @@ -91,8 +93,8 @@ struct ThreadwiseTensorSliceTransfer_v3r2 const Indices& dst_slice_origin_idxs) { static_for<0, nDst, 1>{}([&](auto dst_i) { - dst_coords_(dst_i) = - make_tensor_coordinate(dst_descs.At(dst_i), dst_slice_origin_idxs[dst_i]); + dst_coords_(dst_i) = make_tensor_coordinate(dst_descs.At(dst_i), + dst_slice_origin_idxs[dst_i]); }); } @@ -172,10 +174,10 @@ struct 
ThreadwiseTensorSliceTransfer_v3r2 SrcsScalarPerVector::At(src_i)>; using src_vector_t = typename src_vector_type::type; + const IndexType ld_offset = src_coords_.At(src_i).GetOffset(); // copy data from src_buf into src_vector_container - auto src_vector_container = - src_vector_type{src_bufs.At(src_i).template Get( - src_coords_.At(src_i).GetOffset(), is_src_valid)}; + auto src_vector_container = src_vector_type{ + src_bufs.At(src_i).template Get(ld_offset, is_src_valid)}; // copy data from src_vector_container into src_thread_scratch_ src_thread_scratch_tuple_(thread_scratch_id) @@ -338,8 +340,9 @@ struct ThreadwiseTensorSliceTransfer_v3r2 static_cast(DstInMemOps::At(dst_i.value)); // copy data from dst_vector_container to dst_buf + const IndexType st_offset = dst_coords_.At(dst_i).GetOffset(); dst_bufs.At(dst_i).template Update( - dst_coords_.At(dst_i).GetOffset(), + st_offset, is_dst_valid, dst_vector_container.template AsType()[Helper::I0]); diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp index 088604b028..5976fc4c19 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -13,40 +13,50 @@ namespace ck { namespace tensor_operation { -/** - * @brief Enable custom tensor transform for convolution backward data output. - * - * When set to 1, this macro enables a custom transformation of the output tensor - * in convolution backward data operations. 
- */ -#define CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT 1 - template < index_t NDimSpatial, ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization, - index_t AK1, - index_t BK1, - index_t GemmMPerBlock, - index_t GemmNPerBlock, - index_t GemmKPerBlock, + index_t AK1_, + index_t BK1_, + index_t GemmMPerBlock_, + index_t GemmNPerBlock_, + index_t GemmKPerBlock_, bool DoPadGemmM, bool DoPadGemmN, typename ALayout, typename BLayout, typename CLayout, - bool SplitN = false, - typename ADataType = float, - typename CDataType = float, - index_t NumGroupsToMerge = 1, - typename IndexType = index_t, - bool CTranspose = false> + bool SplitN = false, + typename ADataType = float, + typename CDataType = float, + index_t NumGroupsToMerge_ = 1, + typename IndexType = index_t, + bool CTranspose = false> struct TransformConvBwdDataToGemm_v1 { private: - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; + /** + * @brief Enable custom tensor transform for convolution backward data output. + * + * When set to 1, this variable enables a custom transformation of the output tensor + * in convolution backward data operations. 
+ */ + static constexpr bool CustomTensorTransformBwdData = std::is_same_v; + + template + using NumberType = + std::conditional_t, Number, LongNumber>; + static constexpr auto I0 = NumberType<0>{}; + static constexpr auto I1 = NumberType<1>{}; + static constexpr auto I2 = NumberType<2>{}; + static constexpr auto I3 = NumberType<3>{}; + + static constexpr IndexType AK1 = static_cast(AK1_); + static constexpr IndexType BK1 = static_cast(BK1_); + static constexpr IndexType GemmMPerBlock = static_cast(GemmMPerBlock_); + static constexpr IndexType GemmNPerBlock = static_cast(GemmNPerBlock_); + static constexpr IndexType GemmKPerBlock = static_cast(GemmKPerBlock_); + static constexpr IndexType NumGroupsToMerge = static_cast(NumGroupsToMerge_); static constexpr auto NonSpatialDimsNum = Number<3>{}; @@ -197,7 +207,7 @@ struct TransformConvBwdDataToGemm_v1 ZDot_{static_cast(transform_conv_bwd_data_to_gemm_base.ZDot_)}, YDot_{static_cast(transform_conv_bwd_data_to_gemm_base.YDot_)}, XDot_{static_cast(transform_conv_bwd_data_to_gemm_base.XDot_)}, - batch_k_{transform_conv_bwd_data_to_gemm_base.batch_k_} + batch_k_{static_cast(transform_conv_bwd_data_to_gemm_base.batch_k_)} { } @@ -278,21 +288,24 @@ struct TransformConvBwdDataToGemm_v1 } else { - Di_ = Do_ = Z_ = ZTilde_ = ConvStrideD_ = DTilde_ = ZDot_ = 1; - InLeftPadD_ = InRightPadD_ = DiStride_ = DoStride_ = IdxZTilde_ = 0; + Di_ = Do_ = Z_ = ZTilde_ = ConvStrideD_ = DTilde_ = ZDot_ = static_cast(1); + InLeftPadD_ = InRightPadD_ = DiStride_ = DoStride_ = IdxZTilde_ = + static_cast(0); } - GcdStrideDilationH_ = math::gcd(ConvStrideH_, ConvDilationH_); - GcdStrideDilationW_ = math::gcd(ConvStrideW_, ConvDilationW_); + GcdStrideDilationH_ = static_cast(math::gcd(ConvStrideH_, ConvDilationH_)); + GcdStrideDilationW_ = static_cast(math::gcd(ConvStrideW_, ConvDilationW_)); YTilde_ = ConvStrideH_ / GcdStrideDilationH_; XTilde_ = ConvStrideW_ / GcdStrideDilationW_; - HTilde_ = Ho_ + math::integer_divide_ceil(ConvDilationH_ * (Y_ 
- I1), ConvStrideH_); - WTilde_ = Wo_ + math::integer_divide_ceil(ConvDilationW_ * (X_ - I1), ConvStrideW_); + HTilde_ = static_cast( + Ho_ + math::integer_divide_ceil(ConvDilationH_ * (Y_ - I1), ConvStrideH_)); + WTilde_ = static_cast( + Wo_ + math::integer_divide_ceil(ConvDilationW_ * (X_ - I1), ConvStrideW_)); - YDot_ = math::integer_divide_ceil(Y_, YTilde_); - XDot_ = math::integer_divide_ceil(X_, XTilde_); + YDot_ = static_cast(math::integer_divide_ceil(Y_, YTilde_)); + XDot_ = static_cast(math::integer_divide_ceil(X_, XTilde_)); } #if 0 // At now not supported to split tensor @@ -665,8 +678,8 @@ struct TransformConvBwdDataToGemm_v1 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { - const index_t K0PerBlock = GemmKPerBlock / AK1; - const index_t AK0 = + const IndexType K0PerBlock = GemmKPerBlock / AK1; + const IndexType AK0 = math::integer_divide_ceil(K_, AK1 * K0PerBlock * batch_k_) * K0PerBlock; // A: output tensor @@ -713,122 +726,132 @@ struct TransformConvBwdDataToGemm_v1 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; // GemmK is different for each GEMM - const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_); - const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_); - const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_); + const auto ZDotSlice = + static_cast(math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_)); + const auto YDotSlice = + static_cast(math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_)); + const auto XDotSlice = + static_cast(math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_)); if constexpr(NDimSpatial == 2) { - const index_t K0PerBlock = GemmKPerBlock / AK1; - const index_t AK0 = math::integer_divide_ceil(YDotSlice * XDotSlice * K_, - AK1 * K0PerBlock * batch_k_) * - K0PerBlock; + const IndexType K0PerBlock = GemmKPerBlock / AK1; + const IndexType AK0 = math::integer_divide_ceil(YDotSlice * XDotSlice * K_, + AK1 * 
K0PerBlock * batch_k_) * + K0PerBlock; -#if CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT == 0 - // A: output tensor - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Ho_, I0, I0), - make_pad_transform(Wo_, I0, I0), - make_pass_through_transform(K_)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple( - make_pass_through_transform(N_), - make_embed_transform(make_tuple(YDot_, HTilde_), - make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, WTilde_), - make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(K_)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = - transform_tensor_descriptor( - out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + if constexpr(!CustomTensorTransformBwdData) + { + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, make_tuple(make_pass_through_transform(N_), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), make_pass_through_transform(K_)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{})); + 
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), - make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = + transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform( + make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform( + make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto out_gemmk_gemmm_padded_grid_desc = - ck::tensor_operation::device::PadTensorDescriptor( - out_gemmk_gemmmraw_grid_desc, - make_tuple(GemmKPerBlock, GemmMPerBlock), - Sequence{}); + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); - const 
auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( - out_gemmk_gemmm_padded_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), - make_pass_through_transform( - out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - return out_gemmak0_gemmm_gemmak1_grid_desc; -#else - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Ho_, I0, I0), - make_pad_transform(Wo_, I0, I0), - make_pass_through_transform(K_)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); - const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple(make_conv_bwd_data_out_transform(N_, - Ho_, - Wo_, - K_, - YDot_, - XDot_, - HTilde_, - WTilde_, - ConvDilationH_, - ConvDilationW_, - HTildeSlice, - WTildeSlice, - YDotSlice, - XDotSlice, - IHTildeSliceBegin, - IWTildeSliceBegin, - GcdStrideDilationH_, - GcdStrideDilationW_, - AK0 * batch_k_, - AK1, - GemmMPerBlock, - GemmKPerBlock)), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0, 1, 2>{})); + const auto out_gemmk_gemmm_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmk_gemmmraw_grid_desc, + make_tuple(GemmKPerBlock, GemmMPerBlock), + Sequence{}); - return out_n_hop_wop_k_grid_desc_final; -#endif + const auto out_gemmak0_gemmm_gemmak1_grid_desc = 
transform_tensor_descriptor( + out_gemmk_gemmm_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), + make_pass_through_transform( + out_gemmk_gemmm_padded_grid_desc.GetLength(Number<1>{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return out_gemmak0_gemmm_gemmak1_grid_desc; + } + else + { + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_conv_bwd_data_out_transform(N_, + Ho_, + Wo_, + K_, + YDot_, + XDot_, + HTilde_, + WTilde_, + ConvDilationH_, + ConvDilationW_, + HTildeSlice, + WTildeSlice, + YDotSlice, + XDotSlice, + IHTildeSliceBegin, + IWTildeSliceBegin, + GcdStrideDilationH_, + GcdStrideDilationW_, + AK0 * batch_k_, + AK1, + GemmMPerBlock, + GemmKPerBlock)), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0, 1, 2>{})); + + return out_n_hop_wop_k_grid_desc_final; + } } else if constexpr(NDimSpatial == 3) { @@ -915,17 +938,17 @@ struct TransformConvBwdDataToGemm_v1 make_tuple(GemmKPerBlock, GemmMPerBlock), Sequence{}); - const index_t K0PerBlock = GemmKPerBlock / AK1; - const index_t AK0 = - math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0), - AK1 * K0PerBlock * batch_k_) * - K0PerBlock; + const IndexType K0PerBlock = GemmKPerBlock / AK1; + const IndexType AK0 = math::integer_divide_ceil( + out_gemmk_gemmm_padded_grid_desc.GetLength(Number<0>{}), + AK1 * K0PerBlock * batch_k_) * + K0PerBlock; const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( 
out_gemmk_gemmm_padded_grid_desc, make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), make_pass_through_transform( - out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), + out_gemmk_gemmm_padded_grid_desc.GetLength(Number<1>{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -953,8 +976,8 @@ struct TransformConvBwdDataToGemm_v1 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { - const index_t K0PerBlock = GemmKPerBlock / BK1; - const index_t BK0 = + const IndexType K0PerBlock = GemmKPerBlock / BK1; + const IndexType BK0 = math::integer_divide_ceil(K_, BK1 * K0PerBlock * batch_k_) * K0PerBlock; // B: weight tensor @@ -983,9 +1006,12 @@ struct TransformConvBwdDataToGemm_v1 const auto wei_grid_desc = MakeWeiGridDesc(); // GemmK is different for each GEMM - const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_); - const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_); - const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_); + const auto ZDotSlice = + static_cast(math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_)); + const auto YDotSlice = + static_cast(math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_)); + const auto XDotSlice = + static_cast(math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_)); // B weight tensor if constexpr(NDimSpatial == 2) @@ -1036,17 +1062,17 @@ struct TransformConvBwdDataToGemm_v1 make_tuple(GemmKPerBlock, GemmNPerBlock), Sequence{}); - const index_t K0PerBlock = GemmKPerBlock / BK1; - const index_t BK0 = - math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0), - BK1 * K0PerBlock * batch_k_) * - K0PerBlock; + const IndexType K0PerBlock = GemmKPerBlock / BK1; + const IndexType BK0 = math::integer_divide_ceil( + wei_gemmk_gemmn_padded_grid_desc.GetLength(Number<0>{}), + BK1 * K0PerBlock * batch_k_) * + K0PerBlock; const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = 
transform_tensor_descriptor( wei_gemmk_gemmn_padded_grid_desc, make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), make_pass_through_transform( - wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), + wei_gemmk_gemmn_padded_grid_desc.GetLength(Number<1>{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -1121,17 +1147,17 @@ struct TransformConvBwdDataToGemm_v1 make_tuple(GemmKPerBlock, GemmNPerBlock), Sequence{}); - const index_t K0PerBlock = GemmKPerBlock / BK1; - const index_t BK0 = - math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0), - BK1 * K0PerBlock * batch_k_) * - K0PerBlock; + const IndexType K0PerBlock = GemmKPerBlock / BK1; + const IndexType BK0 = math::integer_divide_ceil( + wei_gemmk_gemmn_padded_grid_desc.GetLength(Number<0>{}), + BK1 * K0PerBlock * batch_k_) * + K0PerBlock; const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor( wei_gemmk_gemmn_padded_grid_desc, make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), make_pass_through_transform( - wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), + wei_gemmk_gemmn_padded_grid_desc.GetLength(Number<1>{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -1246,18 +1272,15 @@ struct TransformConvBwdDataToGemm_v1 { // only work on DTilde, HTilde and WTilde that contribute to // non-padding area of input tensor - const auto IDTildeSliceBegin = - math::integer_divide_floor(math::max(static_cast(I0), - InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), - ConvStrideD_); - const auto IHTildeSliceBegin = - math::integer_divide_floor(math::max(static_cast(I0), - InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), - ConvStrideH_); - const auto IWTildeSliceBegin = - math::integer_divide_floor(math::max(static_cast(I0), - InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), - ConvStrideW_); + const auto IDTildeSliceBegin = math::integer_divide_floor( + 
math::max(IndexType{0}, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), + ConvStrideD_); + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(IndexType{0}, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), + ConvStrideH_); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(IndexType{0}, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), + ConvStrideW_); const auto IDTildeSliceEnd = math::min( DTilde_, math::integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1); @@ -1491,14 +1514,12 @@ struct TransformConvBwdDataToGemm_v1 static_assert(CTranspose == false); // only work on HTilde and WTilde that contribute to non-padding area of input // tensor - const auto IHTildeSliceBegin = - math::integer_divide_floor(math::max(static_cast(I0), - InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), - ConvStrideH_); - const auto IWTildeSliceBegin = - math::integer_divide_floor(math::max(static_cast(I0), - InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), - ConvStrideW_); + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(IndexType{0}, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), + ConvStrideH_); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(IndexType{0}, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), + ConvStrideW_); const auto IHTildeSliceEnd = math::min( HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1); @@ -1831,7 +1852,7 @@ struct TransformConvBwdDataToGemm_v1 IndexType ZTilde_, YTilde_, XTilde_; IndexType DTilde_, HTilde_, WTilde_; IndexType ZDot_, YDot_, XDot_; - index_t batch_k_; + IndexType batch_k_; }; } // namespace tensor_operation diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp index eb06ce4120..34a6609d99 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp +++ 
b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp @@ -26,8 +26,9 @@ struct TransformConvBwdWeightToGemm static_assert(GemmK1Number > 0, "GemmK1Number must be positive"); static_assert(K0PerBlock > 0, "K0PerBlock must be positive"); - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; + template + using NumberType = + std::conditional_t, Number, LongNumber>; template ::type = false> constexpr static auto @@ -38,7 +39,7 @@ struct TransformConvBwdWeightToGemm const std::array& output_strides) { const IndexType WoStride = output_strides[4]; - const auto KStride = Number<1>{}; + const auto KStride = NumberType<1>{}; return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K), make_tuple(WoStride, KStride)); } @@ -76,7 +77,7 @@ struct TransformConvBwdWeightToGemm const IndexType C, const std::array& weights_strides) { - const auto CStride = Number<1>{}; + const auto CStride = NumberType<1>{}; const auto KStride = weights_strides[1]; return make_naive_tensor_descriptor(make_tuple(K, Y * X * C), make_tuple(KStride, CStride)); } @@ -91,7 +92,7 @@ struct TransformConvBwdWeightToGemm const std::array& output_strides) { const IndexType WoStride = output_strides[5]; - const auto KStride = Number<1>{}; + const auto KStride = NumberType<1>{}; return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), make_tuple(WoStride, KStride)); } @@ -133,7 +134,7 @@ struct TransformConvBwdWeightToGemm const IndexType C, const std::array& weights_strides) { - const auto CStride = Number<1>{}; + const auto CStride = NumberType<1>{}; const auto KStride = weights_strides[1]; return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C), make_tuple(KStride, CStride)); diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index a077a0e8c6..9d9526623c 100644 --- 
a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -37,8 +37,9 @@ struct TransformConvBwdWeightToGemmV2 static_assert(GemmK1Number > 0, "GemmK1Number must be positive"); static_assert(K0PerBlock > 0, "K0PerBlock must be positive"); - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; + template + using NumberType = + std::conditional_t, Number, LongNumber>; template ::type = false> constexpr static auto @@ -49,7 +50,7 @@ struct TransformConvBwdWeightToGemmV2 { const auto BatchStride = output_strides[0]; const auto WoStride = output_strides[3]; - const auto KStride = Number<1>{}; + const auto KStride = NumberType<1>{}; return make_naive_tensor_descriptor(make_tuple(N * Wo, NumGroupsToMerge, K), make_tuple(WoStride, BatchStride, KStride)); } @@ -86,7 +87,7 @@ struct TransformConvBwdWeightToGemmV2 const IndexType C, const std::array& weights_strides) { - const auto CStride = Number<1>{}; + const auto CStride = NumberType<1>{}; const auto KStride = weights_strides[1]; const auto XStride = weights_strides[3]; const auto BatchStride = weights_strides[0]; @@ -138,7 +139,7 @@ struct TransformConvBwdWeightToGemmV2 { const auto BatchStride = output_strides[0]; const auto WoStride = output_strides[4]; - const auto KStride = Number<1>{}; + const auto KStride = NumberType<1>{}; return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumGroupsToMerge, K), make_tuple(WoStride, BatchStride, KStride)); } @@ -178,7 +179,7 @@ struct TransformConvBwdWeightToGemmV2 const IndexType C, const std::array& weights_strides) { - const auto CStride = Number<1>{}; + const auto CStride = NumberType<1>{}; const auto KStride = weights_strides[1]; const auto XStride = weights_strides[4]; const auto BatchStride = weights_strides[0]; @@ -231,7 +232,7 @@ struct TransformConvBwdWeightToGemmV2 { const auto BatchStride = 
output_strides[0]; const auto WoStride = output_strides[5]; - const auto KStride = Number<1>{}; + const auto KStride = NumberType<1>{}; return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumGroupsToMerge, K), make_tuple(WoStride, BatchStride, KStride)); } @@ -274,7 +275,7 @@ struct TransformConvBwdWeightToGemmV2 const IndexType C, const std::array& weights_strides) { - const auto CStride = Number<1>{}; + const auto CStride = NumberType<1>{}; const auto KStride = weights_strides[1]; const auto XStride = weights_strides[5]; const auto BatchStride = weights_strides[0]; diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index c114b90ee2..61efd2c0cb 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -23,12 +23,16 @@ template {}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; + template + using NumberType = + std::conditional_t, Number, LongNumber>; + + static constexpr auto I0 = NumberType<0>{}; + static constexpr auto I1 = NumberType<1>{}; + static constexpr auto I2 = NumberType<2>{}; + static constexpr auto I3 = NumberType<3>{}; + static constexpr auto I4 = NumberType<4>{}; + static constexpr auto I5 = NumberType<5>{}; template static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths, @@ -553,7 +557,7 @@ struct TransformConvFwdToGemm const auto in_n_x_wo_c_desc = transform_tensor_descriptor( in_n_wip_c_desc, make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_embed_transform(make_tuple(NumberType<3>{}, Wo_), make_tuple(ConvDilationW_, ConvStrideW_))), make_tuple(Sequence<0>{}, 
Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{})); @@ -561,7 +565,7 @@ struct TransformConvFwdToGemm return transform_tensor_descriptor( in_n_x_wo_c_desc, make_tuple(make_merge_transform(make_tuple(N_, Wo_)), - make_pass_through_transform(Number<3>{})), + make_pass_through_transform(NumberType<3>{})), make_tuple(Sequence<0, 2>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -582,7 +586,7 @@ struct TransformConvFwdToGemm const auto in_n_x_wo_c_desc = transform_tensor_descriptor( in_n_wip_c_desc, make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_embed_transform(make_tuple(NumberType<3>{}, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), make_pass_through_transform(NumGroupsToMerge)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), @@ -591,7 +595,7 @@ struct TransformConvFwdToGemm return transform_tensor_descriptor( in_n_x_wo_c_desc, make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), - make_pass_through_transform(Number<3>{})), + make_pass_through_transform(NumberType<3>{})), make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -769,9 +773,9 @@ struct TransformConvFwdToGemm const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( in_n_hip_wip_c_desc, make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_embed_transform(make_tuple(NumberType<3>{}, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_embed_transform(make_tuple(NumberType<3>{}, Wo_), make_tuple(ConvDilationW_, ConvStrideW_))), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{})); @@ -779,7 +783,7 @@ struct TransformConvFwdToGemm return transform_tensor_descriptor( in_n_y_ho_x_wo_c_desc, make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), - 
make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))), + make_merge_transform(make_tuple(NumberType<3>{}, NumberType<3>{}))), make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -801,9 +805,9 @@ struct TransformConvFwdToGemm const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( in_n_hip_wip_groups_c_desc, make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_embed_transform(make_tuple(NumberType<3>{}, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_embed_transform(make_tuple(NumberType<3>{}, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), make_pass_through_transform(NumGroupsToMerge)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), @@ -812,7 +816,7 @@ struct TransformConvFwdToGemm return transform_tensor_descriptor( in_n_y_ho_x_wo_groups_c_desc, make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), - make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))), + make_merge_transform(make_tuple(NumberType<3>{}, NumberType<3>{}))), make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -1017,11 +1021,11 @@ struct TransformConvFwdToGemm const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( in_n_hip_wip_c_desc, make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(Number<3>{}, Do_), + make_embed_transform(make_tuple(NumberType<3>{}, Do_), make_tuple(ConvDilationD_, ConvStrideD_)), - make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_embed_transform(make_tuple(NumberType<3>{}, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_embed_transform(make_tuple(NumberType<3>{}, Wo_), make_tuple(ConvDilationW_, ConvStrideW_))), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), 
make_tuple( @@ -1029,9 +1033,9 @@ struct TransformConvFwdToGemm return transform_tensor_descriptor( in_n_z_do_y_ho_x_wo_c_desc, - make_tuple( - make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), - make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))), + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_merge_transform( + make_tuple(NumberType<3>{}, NumberType<3>{}, NumberType<3>{}))), make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -1056,11 +1060,11 @@ struct TransformConvFwdToGemm const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( in_n_hip_wip_c_desc, make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(Number<3>{}, Do_), + make_embed_transform(make_tuple(NumberType<3>{}, Do_), make_tuple(ConvDilationD_, ConvStrideD_)), - make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_embed_transform(make_tuple(NumberType<3>{}, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_embed_transform(make_tuple(NumberType<3>{}, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), make_pass_through_transform(NumGroupsToMerge)), make_tuple( @@ -1075,7 +1079,8 @@ struct TransformConvFwdToGemm in_n_z_do_y_ho_x_wo_c_desc, make_tuple( make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), - make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))), + make_merge_transform( + make_tuple(NumberType<3>{}, NumberType<3>{}, NumberType<3>{}))), make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -1348,10 +1353,10 @@ struct TransformConvFwdToGemm if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter3x3) { - using FilterSizeNumType = - ck::conditional_t, - ck::conditional_t, Number<27>>>; + using FilterSizeNumType = ck::conditional_t< + NDimSpatial == 1, + NumberType<3>, + 
ck::conditional_t, NumberType<27>>>; if constexpr(NumGroupsToMerge == 1) { diff --git a/include/ck/utility/number.hpp b/include/ck/utility/number.hpp index a8e2dcbddb..30c248cf29 100644 --- a/include/ck/utility/number.hpp +++ b/include/ck/utility/number.hpp @@ -14,5 +14,38 @@ using Number = integral_constant; template using LongNumber = integral_constant; +// --------------------------------------------------------------------------- +// is_number -- true if T is a specialization of Number (integral_constant) +// --------------------------------------------------------------------------- +template +struct is_number : false_type +{ +}; + +template +struct is_number> : true_type +{ +}; + +template +inline constexpr bool is_number_v = is_number::value; + +// --------------------------------------------------------------------------- +// is_long_number -- true if T is a specialization of LongNumber (integral_constant) +// --------------------------------------------------------------------------- +template +struct is_long_number : false_type +{ +}; + +template +struct is_long_number> : true_type +{ +}; + +template +inline constexpr bool is_long_number_v = is_long_number::value; + } // namespace ck #endif diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp index 5210265cef..601c24c6a9 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp @@ -15,50 +15,66 @@ namespace ck { namespace ref { -// Optimized backward data convolution kernel working with packed (contiguous) tensors with -// multi-ABD support Computes gradients w.r.t. 
input from output gradients and weights Assumes -// row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], output[G][N][K][spatial] template -__global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_in, - const WeiDataType* const* __restrict__ p_weis, - const OutDataType* const* __restrict__ p_outs, - const DDataType* const* __restrict__ p_ds, - const index_t* const* __restrict__ p_d_strides, - index_t G, - index_t N, - index_t K, - index_t C, - index_t Di, - index_t Hi, - index_t Wi, - index_t Z, - index_t Y, - index_t X, - index_t Do, - index_t Ho, - index_t Wo, - index_t stride_z, - index_t stride_y, - index_t stride_x, - index_t dilation_z, - index_t dilation_y, - index_t dilation_x, - index_t pad_z, - index_t pad_y, - index_t pad_x, - InElementOp in_op, - WeiElementOp wei_op, - OutElementOp out_op) +__global__ void +naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_in, + const WeiDataType* const* __restrict__ p_weis, + const OutDataType* const* __restrict__ p_outs, + const DDataType* const* __restrict__ p_ds, + const long_index_t* const* __restrict__ p_d_strides, + long_index_t G, + long_index_t N, + long_index_t K, + long_index_t C, + long_index_t Di, + long_index_t Hi, + long_index_t Wi, + long_index_t Z, + long_index_t Y, + long_index_t X, + long_index_t Do, + long_index_t Ho, + long_index_t Wo, + long_index_t stride_z, + long_index_t stride_y, + long_index_t stride_x, + long_index_t dilation_z, + long_index_t dilation_y, + long_index_t dilation_x, + long_index_t pad_z, + long_index_t pad_y, + long_index_t pad_x, + long_index_t in_sg, + long_index_t in_sn, + long_index_t in_sc, + long_index_t in_sd, + long_index_t in_sh, + long_index_t in_sw, + long_index_t wei_sg, + long_index_t wei_sk, + long_index_t wei_sc, + long_index_t wei_sz, + long_index_t wei_sy, + long_index_t wei_sx, + long_index_t out_sg, + long_index_t out_sn, + long_index_t out_sk, + long_index_t out_sd, + long_index_t out_sh, + long_index_t 
out_sw, + InElementOp in_op, + WeiElementOp wei_op, + OutElementOp out_op) { const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x; const long_index_t num_threads = blockDim.x * gridDim.x; @@ -69,31 +85,21 @@ __global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_ if constexpr(NDimSpatial == 1) { - const long_index_t num_in = G * N * C * Wi; - const long_index_t out_stride_g = N * K * Wo; - const long_index_t out_stride_n = K * Wo; - const long_index_t out_stride_k = Wo; - const long_index_t wei_stride_g = K * C * X; - const long_index_t wei_stride_k = C * X; - const long_index_t wei_stride_c = X; - const long_index_t in_stride_g = N * C * Wi; - const long_index_t in_stride_n = C * Wi; - const long_index_t in_stride_c = Wi; + const long_index_t num_in = G * N * C * Wi; for(long_index_t idx = tid; idx < num_in; idx += num_threads) { - index_t remaining = idx; - const index_t wi = remaining % Wi; + long_index_t remaining = idx; + const long_index_t wi = remaining % Wi; remaining /= Wi; - const index_t c = remaining % C; + const long_index_t c = remaining % C; remaining /= C; - const index_t n = remaining % N; - const index_t g = remaining / N; + const long_index_t n = remaining % N; + const long_index_t g = remaining / N; - float acc = 0.0f; - // Base pointers for current group and batch - const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n; - const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g; + float acc = 0.0f; + const OutDataType* output_grad_g_n = p_outs[0] + g * out_sg + n * out_sn; + const WeiDataType* weight_g = p_weis[0] + g * wei_sg; for(index_t x = 0; x < X; ++x) { @@ -103,29 +109,26 @@ __global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_ long_index_t wo = w_tmp / stride_x; if(wo >= 0 && wo < Wo) { - // Pointers at current filter position const OutDataType* output_grad_g_n_k = output_grad_g_n; - const WeiDataType* weight_g_k_c = weight_g + c * 
wei_stride_c; + const WeiDataType* weight_g_k_c = weight_g + c * wei_sc; for(index_t k = 0; k < K; ++k) { - // Handle output gradient element-wise operation with extra A tensors detail::apply_multi_tensor_elementwise_op( out_val, out_op, output_grad_g_n_k, p_outs + 1, - g * out_stride_g + n * out_stride_n, - k * out_stride_k + wo); + g * out_sg + n * out_sn, + k * out_sk + wo * out_sw); - // Handle weight element-wise operation with extra B tensors detail::apply_multi_tensor_elementwise_op( wei_val, wei_op, weight_g_k_c, p_weis + 1, - g * wei_stride_g + c * wei_stride_c, - k * wei_stride_k + x); + g * wei_sg + c * wei_sc, + k * wei_sk + x * wei_sx); acc += type_convert(out_val) * type_convert(wei_val); } @@ -136,41 +139,28 @@ __global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_ detail::apply_d_tensor_elementwise_op( in_val, in_op, acc, p_ds, p_d_strides, g, n, c, wi); - p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + wi] = in_val; + p_in[g * in_sg + n * in_sn + c * in_sc + wi * in_sw] = in_val; } } else if constexpr(NDimSpatial == 2) { - const long_index_t num_in = G * N * C * Hi * Wi; - const long_index_t out_stride_g = N * K * Ho * Wo; - const long_index_t out_stride_n = K * Ho * Wo; - const long_index_t out_stride_k = Ho * Wo; - const long_index_t out_stride_h = Wo; - const long_index_t wei_stride_g = K * C * Y * X; - const long_index_t wei_stride_k = C * Y * X; - const long_index_t wei_stride_c = Y * X; - const long_index_t wei_stride_y = X; - const long_index_t in_stride_g = N * C * Hi * Wi; - const long_index_t in_stride_n = C * Hi * Wi; - const long_index_t in_stride_c = Hi * Wi; - const long_index_t in_stride_h = Wi; + const long_index_t num_in = G * N * C * Hi * Wi; for(long_index_t idx = tid; idx < num_in; idx += num_threads) { - index_t remaining = idx; - const index_t wi = remaining % Wi; + long_index_t remaining = idx; + const long_index_t wi = remaining % Wi; remaining /= Wi; - const index_t hi = remaining % Hi; 
+ const long_index_t hi = remaining % Hi; remaining /= Hi; - const index_t c = remaining % C; + const long_index_t c = remaining % C; remaining /= C; - const index_t n = remaining % N; - const index_t g = remaining / N; + const long_index_t n = remaining % N; + const long_index_t g = remaining / N; - float acc = 0.0f; - // Base pointers for current group and batch - const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n; - const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g; + float acc = 0.0f; + const OutDataType* output_grad_g_n = p_outs[0] + g * out_sg + n * out_sn; + const WeiDataType* weight_g = p_weis[0] + g * wei_sg; for(index_t y = 0; y < Y; ++y) { @@ -180,10 +170,8 @@ __global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_ long_index_t ho = h_tmp / stride_y; if(ho >= 0 && ho < Ho) { - // Pointers at current spatial height and filter Y position - const OutDataType* output_grad_at_h = output_grad_g_n + ho * out_stride_h; - const WeiDataType* weight_at_c_y = - weight_g + c * wei_stride_c + y * wei_stride_y; + const OutDataType* output_grad_at_h = output_grad_g_n + ho * out_sh; + const WeiDataType* weight_at_c_y = weight_g + c * wei_sc + y * wei_sy; for(index_t x = 0; x < X; ++x) { @@ -195,24 +183,21 @@ __global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_ { for(index_t k = 0; k < K; ++k) { - // Handle output gradient element-wise operation with extra - // A tensors detail::apply_multi_tensor_elementwise_op( out_val, out_op, output_grad_at_h, p_outs + 1, - g * out_stride_g + n * out_stride_n + ho * out_stride_h, - k * out_stride_k + wo); + g * out_sg + n * out_sn + ho * out_sh, + k * out_sk + wo * out_sw); - // Handle weight element-wise operation with extra B tensors detail::apply_multi_tensor_elementwise_op( wei_val, wei_op, weight_at_c_y, p_weis + 1, - g * wei_stride_g + c * wei_stride_c + y * wei_stride_y, - k * wei_stride_k + x); + g * wei_sg + c * wei_sc + y * 
wei_sy, + k * wei_sk + x * wei_sx); acc += type_convert(out_val) * type_convert(wei_val); @@ -235,47 +220,30 @@ __global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_ hi * p_d_strides[0][3] + wi * p_d_strides[0][4]); - p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + hi * in_stride_h + wi] = - in_val; + p_in[g * in_sg + n * in_sn + c * in_sc + hi * in_sh + wi * in_sw] = in_val; } } else if constexpr(NDimSpatial == 3) { - const long_index_t num_in = G * N * C * Di * Hi * Wi; - const long_index_t out_stride_g = N * K * Do * Ho * Wo; - const long_index_t out_stride_n = K * Do * Ho * Wo; - const long_index_t out_stride_k = Do * Ho * Wo; - const long_index_t out_stride_d = Ho * Wo; - const long_index_t out_stride_h = Wo; - const long_index_t wei_stride_g = K * C * Z * Y * X; - const long_index_t wei_stride_k = C * Z * Y * X; - const long_index_t wei_stride_c = Z * Y * X; - const long_index_t wei_stride_z = Y * X; - const long_index_t wei_stride_y = X; - const long_index_t in_stride_g = N * C * Di * Hi * Wi; - const long_index_t in_stride_n = C * Di * Hi * Wi; - const long_index_t in_stride_c = Di * Hi * Wi; - const long_index_t in_stride_d = Hi * Wi; - const long_index_t in_stride_h = Wi; + const long_index_t num_in = G * N * C * Di * Hi * Wi; for(long_index_t idx = tid; idx < num_in; idx += num_threads) { - index_t remaining = idx; - const index_t wi = remaining % Wi; + long_index_t remaining = idx; + const long_index_t wi = remaining % Wi; remaining /= Wi; - const index_t hi = remaining % Hi; + const long_index_t hi = remaining % Hi; remaining /= Hi; - const index_t di = remaining % Di; + const long_index_t di = remaining % Di; remaining /= Di; - const index_t c = remaining % C; + const long_index_t c = remaining % C; remaining /= C; - const index_t n = remaining % N; - const index_t g = remaining / N; + const long_index_t n = remaining % N; + const long_index_t g = remaining / N; - float acc = 0.0f; - // Base pointers for current 
group and batch - const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n; - const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g; + float acc = 0.0f; + const OutDataType* output_grad_g_n = p_outs[0] + g * out_sg + n * out_sn; + const WeiDataType* weight_g = p_weis[0] + g * wei_sg; for(index_t z = 0; z < Z; ++z) { @@ -285,11 +253,8 @@ __global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_ long_index_t do_idx = d_tmp / stride_z; if(do_idx >= 0 && do_idx < Do) { - // Pointers at current spatial depth - const OutDataType* output_grad_at_d = - output_grad_g_n + do_idx * out_stride_d; - const WeiDataType* weight_at_c_z = - weight_g + c * wei_stride_c + z * wei_stride_z; + const OutDataType* output_grad_at_d = output_grad_g_n + do_idx * out_sd; + const WeiDataType* weight_at_c_z = weight_g + c * wei_sc + z * wei_sz; for(index_t y = 0; y < Y; ++y) { @@ -299,11 +264,9 @@ __global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_ long_index_t ho = h_tmp / stride_y; if(ho >= 0 && ho < Ho) { - // Pointers at current spatial depth and height const OutDataType* output_grad_at_d_h = - output_grad_at_d + ho * out_stride_h; - const WeiDataType* weight_at_c_z_y = - weight_at_c_z + y * wei_stride_y; + output_grad_at_d + ho * out_sh; + const WeiDataType* weight_at_c_z_y = weight_at_c_z + y * wei_sy; for(index_t x = 0; x < X; ++x) { @@ -315,30 +278,24 @@ __global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_ { for(index_t k = 0; k < K; ++k) { - // Handle output gradient element-wise operation - // with extra A tensors detail::apply_multi_tensor_elementwise_op< NumAExtra>(out_val, out_op, output_grad_at_d_h, p_outs + 1, - g * out_stride_g + - n * out_stride_n + - do_idx * out_stride_d + - ho * out_stride_h, - k * out_stride_k + wo); + g * out_sg + n * out_sn + + do_idx * out_sd + + ho * out_sh, + k * out_sk + wo * out_sw); - // Handle weight element-wise operation with - // 
extra B tensors detail::apply_multi_tensor_elementwise_op< - NumBExtra>( - wei_val, - wei_op, - weight_at_c_z_y, - p_weis + 1, - g * wei_stride_g + c * wei_stride_c + - z * wei_stride_z + y * wei_stride_y, - k * wei_stride_k + x); + NumBExtra>(wei_val, + wei_op, + weight_at_c_z_y, + p_weis + 1, + g * wei_sg + c * wei_sc + + z * wei_sz + y * wei_sy, + k * wei_sk + x * wei_sx); acc += type_convert(out_val) * type_convert(wei_val); @@ -364,13 +321,11 @@ __global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_ c, di * p_d_strides[0][3] + hi * p_d_strides[0][4] + wi * p_d_strides[0][5]); - p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + di * in_stride_d + - hi * in_stride_h + wi] = in_val; + p_in[g * in_sg + n * in_sn + c * in_sc + di * in_sd + hi * in_sh + wi * in_sw] = in_val; } } } -// GPU reference backward data convolution with multi-ABD support - takes ConvParam directly template // D tensor type, defaults to TIn for backward compatibility + typename TD = TIn> void naive_conv_bwd_data_multi_abd( TIn* p_in, const std::array& p_weis, const std::array& p_outs, const std::array& p_ds, const ck::utils::conv::ConvParam& conv_param, - [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, - const std::array, NumDElementwise>& d_strides, + [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, + const std::array, NumDElementwise>& d_strides, InElementwiseOperation in_element_op = InElementwiseOperation{}, WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, OutElementwiseOperation out_element_op = OutElementwiseOperation{}, @@ -399,115 +354,35 @@ void naive_conv_bwd_data_multi_abd( { const auto ndim = conv_param.num_dim_spatial_; - const index_t G = conv_param.G_; - const index_t N = conv_param.N_; - const index_t C = conv_param.C_; - const index_t K = conv_param.K_; + const long_index_t G = conv_param.G_; + const long_index_t N = conv_param.N_; + const long_index_t C = conv_param.C_; + const long_index_t 
K = conv_param.K_; - std::vector in_lengths = {G, N, C}; - std::vector wei_lengths = {G, K, C}; - std::vector out_lengths = {G, N, K}; + std::vector in_lengths = {G, N, C}; + std::vector wei_lengths = {G, K, C}; + std::vector out_lengths = {G, N, K}; for(index_t i = 0; i < ndim; ++i) { - in_lengths.push_back(static_cast(conv_param.input_spatial_lengths_[i])); - wei_lengths.push_back(static_cast(conv_param.filter_spatial_lengths_[i])); - out_lengths.push_back(static_cast(conv_param.output_spatial_lengths_[i])); + in_lengths.push_back(static_cast(conv_param.input_spatial_lengths_[i])); + wei_lengths.push_back(static_cast(conv_param.filter_spatial_lengths_[i])); + out_lengths.push_back(static_cast(conv_param.output_spatial_lengths_[i])); } - // Calculate total elements for buffer allocation - long_index_t in_total = 1, wei_total = 1, out_total = 1; + long_index_t in_total = 1; for(auto l : in_lengths) in_total *= l; - for(auto l : wei_lengths) - wei_total *= l; - for(auto l : out_lengths) - out_total *= l; - // Allocate packed buffers - SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn)); - - std::vector wei_packed_bufs; - wei_packed_bufs.reserve(NumBElementwise + 1); - for(index_t i = 0; i <= NumBElementwise; ++i) - { - wei_packed_bufs.emplace_back(wei_total * sizeof(TWei)); - } - - std::vector out_packed_bufs; - out_packed_bufs.reserve(NumAElementwise + 1); - for(index_t i = 0; i <= NumAElementwise; ++i) - { - out_packed_bufs.emplace_back(out_total * sizeof(TOut)); - } - - TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); - - std::array p_weis_packed; - for(index_t i = 0; i <= NumBElementwise; ++i) - { - p_weis_packed[i] = static_cast(wei_packed_bufs[i].GetDeviceBuffer()); - } - - std::array p_outs_packed; - for(index_t i = 0; i <= NumAElementwise; ++i) - { - p_outs_packed[i] = static_cast(out_packed_bufs[i].GetDeviceBuffer()); - } - - // Compute strides and allocate device arrays for pack/unpack - std::vector in_strides = 
compute_conv_tensor_strides(in_lengths, ndim); - std::vector wei_strides = compute_conv_tensor_strides(wei_lengths, ndim); - std::vector out_strides = compute_conv_tensor_strides(out_lengths, ndim); - - const size_t dim_count = in_lengths.size(); - SimpleDeviceMem in_lengths_buf(dim_count * sizeof(index_t)); - SimpleDeviceMem in_strides_buf(dim_count * sizeof(index_t)); - SimpleDeviceMem wei_lengths_buf(dim_count * sizeof(index_t)); - SimpleDeviceMem wei_strides_buf(dim_count * sizeof(index_t)); - SimpleDeviceMem out_lengths_buf(dim_count * sizeof(index_t)); - SimpleDeviceMem out_strides_buf(dim_count * sizeof(index_t)); - - index_t* d_in_lengths = static_cast(in_lengths_buf.GetDeviceBuffer()); - index_t* d_in_strides = static_cast(in_strides_buf.GetDeviceBuffer()); - index_t* d_wei_lengths = static_cast(wei_lengths_buf.GetDeviceBuffer()); - index_t* d_wei_strides = static_cast(wei_strides_buf.GetDeviceBuffer()); - index_t* d_out_lengths = static_cast(out_lengths_buf.GetDeviceBuffer()); - index_t* d_out_strides = static_cast(out_strides_buf.GetDeviceBuffer()); - - HIP_CHECK_ERROR(hipMemcpy( - d_in_lengths, in_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy( - d_in_strides, in_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy( - d_wei_lengths, wei_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy( - d_wei_strides, wei_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy( - d_out_lengths, out_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy( - d_out_strides, out_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - - // Pack output and weight tensors to contiguous layout (inputs to bwd data) - constexpr int block_size = 256; - - for(index_t i = 0; i <= NumAElementwise; ++i) - { - strided_copy_kernel - 
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_outs[i], p_outs_packed[i], d_out_lengths, d_out_strides, dim_count, out_total); - } - - for(index_t i = 0; i <= NumBElementwise; ++i) - { - strided_copy_kernel - <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total); - } + std::vector in_strides = compute_conv_tensor_strides(in_lengths, ndim); + std::vector wei_strides = + compute_conv_tensor_strides(wei_lengths, ndim); + std::vector out_strides = + compute_conv_tensor_strides(out_lengths, ndim); // Prepare D tensor stride arrays on device std::vector d_stride_bufs; - std::array p_d_strides_dev = {}; + std::array p_d_strides_dev = {}; if constexpr(NumDElementwise > 0) { @@ -515,35 +390,32 @@ void naive_conv_bwd_data_multi_abd( for(index_t i = 0; i < NumDElementwise; ++i) { - d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t)); - p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); + d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(long_index_t)); + p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i], d_strides[i].data(), - d_strides[i].size() * sizeof(index_t), + d_strides[i].size() * sizeof(long_index_t), hipMemcpyHostToDevice)); } } - // Create device arrays of pointers + // Create device pointer arrays (use original pointers directly, no packing) SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*)); SimpleDeviceMem outs_ptrs_buf((NumAElementwise + 1) * sizeof(TOut*)); SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*)); - SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*)); + SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(long_index_t*)); - TWei** d_weis_ptrs = static_cast(weis_ptrs_buf.GetDeviceBuffer()); - TOut** d_outs_ptrs = static_cast(outs_ptrs_buf.GetDeviceBuffer()); - TD** 
d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); - index_t** d_d_strides_ptrs = static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); + TWei** d_weis_ptrs = static_cast(weis_ptrs_buf.GetDeviceBuffer()); + TOut** d_outs_ptrs = static_cast(outs_ptrs_buf.GetDeviceBuffer()); + TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); + long_index_t** d_d_strides_ptrs = + static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); - HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs, - p_weis_packed.data(), - (NumBElementwise + 1) * sizeof(TWei*), - hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_outs_ptrs, - p_outs_packed.data(), - (NumAElementwise + 1) * sizeof(TOut*), - hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy( + d_weis_ptrs, p_weis.data(), (NumBElementwise + 1) * sizeof(TWei*), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy( + d_outs_ptrs, p_outs.data(), (NumAElementwise + 1) * sizeof(TOut*), hipMemcpyHostToDevice)); if constexpr(NumDElementwise > 0) { @@ -557,23 +429,51 @@ void naive_conv_bwd_data_multi_abd( d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice)); HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs, p_d_strides_dev.data(), - NumDElementwise * sizeof(index_t*), + NumDElementwise * sizeof(long_index_t*), hipMemcpyHostToDevice)); } - // Build conv parameter vectors for kernel invocation - std::vector conv_strides(ndim); - std::vector conv_dilations(ndim); - std::vector input_pads(ndim); + std::vector conv_strides(ndim); + std::vector conv_dilations(ndim); + std::vector input_pads(ndim); for(index_t i = 0; i < ndim; ++i) { - conv_strides[i] = static_cast(conv_param.conv_filter_strides_[i]); - conv_dilations[i] = static_cast(conv_param.conv_filter_dilations_[i]); - input_pads[i] = static_cast(conv_param.input_left_pads_[i]); + conv_strides[i] = static_cast(conv_param.conv_filter_strides_[i]); + conv_dilations[i] = static_cast(conv_param.conv_filter_dilations_[i]); + input_pads[i] = static_cast(conv_param.input_left_pads_[i]); 
} - // Run backward data convolution kernel on packed data - const int in_grid = (in_total + block_size - 1) / block_size; + // Extract strides indexed as [G,N,C,spatial...] and [G,K,C,spatial...] / [G,N,K,spatial...] + // in_strides: [0]=sg [1]=sn [2]=sc [3]=sd [4]=sh [5]=sw + // wei_strides: [0]=sg [1]=sk [2]=sc [3]=sz [4]=sy [5]=sx + // out_strides: [0]=sg [1]=sn [2]=sk [3]=sd [4]=sh [5]=sw + const long_index_t in_sg = in_strides[0]; + const long_index_t in_sn = in_strides[1]; + const long_index_t in_sc = in_strides[2]; + const long_index_t in_sd = (ndim >= 3) ? in_strides[3] : 0; + const long_index_t in_sh = (ndim >= 2) ? in_strides[ndim == 3 ? 4 : 3] : 0; + const long_index_t in_sw = in_strides[ndim == 3 ? 5 : (ndim == 2 ? 4 : 3)]; + + const long_index_t wei_sg = wei_strides[0]; + const long_index_t wei_sk = wei_strides[1]; + const long_index_t wei_sc = wei_strides[2]; + const long_index_t wei_sz = (ndim >= 3) ? wei_strides[3] : 0; + const long_index_t wei_sy = (ndim >= 2) ? wei_strides[ndim == 3 ? 4 : 3] : 0; + const long_index_t wei_sx = wei_strides[ndim == 3 ? 5 : (ndim == 2 ? 4 : 3)]; + + const long_index_t out_sg = out_strides[0]; + const long_index_t out_sn = out_strides[1]; + const long_index_t out_sk = out_strides[2]; + const long_index_t out_sd = (ndim >= 3) ? out_strides[3] : 0; + const long_index_t out_sh = (ndim >= 2) ? out_strides[ndim == 3 ? 4 : 3] : 0; + const long_index_t out_sw = out_strides[ndim == 3 ? 5 : (ndim == 2 ? 4 : 3)]; + + constexpr int block_size = 256; + const long_index_t in_grid_unclamped = (in_total + block_size - 1) / block_size; + // gridDim.x * blockDim.x must not overflow uint32_t; kernel uses a grid-stride loop. 
+ constexpr long_index_t max_grid = + static_cast(std::numeric_limits::max()) / block_size; + const int in_grid = static_cast(std::min(in_grid_unclamped, max_grid)); if(ndim == 1) { @@ -588,7 +488,7 @@ void naive_conv_bwd_data_multi_abd( InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> - <<>>(p_in_packed, + <<>>(p_in, d_weis_ptrs, d_outs_ptrs, d_ds_ptrs, @@ -615,6 +515,24 @@ void naive_conv_bwd_data_multi_abd( 0, 0, input_pads[0], + in_sg, + in_sn, + in_sc, + in_sd, + in_sh, + in_sw, + wei_sg, + wei_sk, + wei_sc, + wei_sz, + wei_sy, + wei_sx, + out_sg, + out_sn, + out_sk, + out_sd, + out_sh, + out_sw, in_element_op, wei_element_op, out_element_op); @@ -632,7 +550,7 @@ void naive_conv_bwd_data_multi_abd( InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> - <<>>(p_in_packed, + <<>>(p_in, d_weis_ptrs, d_outs_ptrs, d_ds_ptrs, @@ -659,6 +577,24 @@ void naive_conv_bwd_data_multi_abd( 0, input_pads[0], input_pads[1], + in_sg, + in_sn, + in_sc, + in_sd, + in_sh, + in_sw, + wei_sg, + wei_sk, + wei_sc, + wei_sz, + wei_sy, + wei_sx, + out_sg, + out_sn, + out_sk, + out_sd, + out_sh, + out_sw, in_element_op, wei_element_op, out_element_op); @@ -676,7 +612,7 @@ void naive_conv_bwd_data_multi_abd( InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> - <<>>(p_in_packed, + <<>>(p_in, d_weis_ptrs, d_outs_ptrs, d_ds_ptrs, @@ -703,21 +639,32 @@ void naive_conv_bwd_data_multi_abd( input_pads[0], input_pads[1], input_pads[2], + in_sg, + in_sn, + in_sc, + in_sd, + in_sh, + in_sw, + wei_sg, + wei_sk, + wei_sc, + wei_sz, + wei_sy, + wei_sx, + out_sg, + out_sn, + out_sk, + out_sd, + out_sh, + out_sw, in_element_op, wei_element_op, out_element_op); } - // Unpack result back to strided layout - strided_copy_kernel<<>>( - p_in_packed, p_in, d_in_lengths, d_in_strides, dim_count, in_total); - HIP_CHECK_ERROR(hipGetLastError()); - - // Memory automatically freed by SimpleDeviceMem destructors } -// Original naive_conv_bwd_data 
- now a zero-overhead wrapper template p_weis = {p_wei}; - std::array p_outs = {p_out}; - std::array p_ds = {}; - std::array, 0> d_lengths = {}; - std::array, 0> d_strides = {}; + std::array p_weis = {p_wei}; + std::array p_outs = {p_out}; + std::array p_ds = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; naive_conv_bwd_data_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_in, p_weis, diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp index 8cee2e2b77..915c0830c2 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp @@ -36,29 +36,47 @@ naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_i WeiDataType* __restrict__ p_wei_grad, const OutDataType* const* __restrict__ p_out_grads, const DDataType* const* __restrict__ p_ds, - const index_t* const* __restrict__ p_d_strides, - index_t G, - index_t N, - index_t K, - index_t C, - index_t Di, - index_t Hi, - index_t Wi, - index_t Z, - index_t Y, - index_t X, - index_t Do, - index_t Ho, - index_t Wo, - index_t stride_z, - index_t stride_y, - index_t stride_x, - index_t dilation_z, - index_t dilation_y, - index_t dilation_x, - index_t pad_z, - index_t pad_y, - index_t pad_x, + const long_index_t* const* __restrict__ p_d_strides, + long_index_t G, + long_index_t N, + long_index_t K, + long_index_t C, + long_index_t Di, + long_index_t Hi, + long_index_t Wi, + long_index_t Z, + long_index_t Y, + long_index_t X, + long_index_t Do, + long_index_t Ho, + long_index_t Wo, + long_index_t stride_z, + long_index_t stride_y, + long_index_t stride_x, + long_index_t dilation_z, + long_index_t dilation_y, + long_index_t dilation_x, + long_index_t pad_z, + long_index_t pad_y, + long_index_t pad_x, + long_index_t 
in_sg, + long_index_t in_sn, + long_index_t in_sc, + long_index_t in_sd, + long_index_t in_sh, + long_index_t in_sw, + long_index_t out_sg, + long_index_t out_sn, + long_index_t out_sk, + long_index_t out_sd, + long_index_t out_sh, + long_index_t out_sw, + long_index_t wei_sg, + long_index_t wei_sk, + long_index_t wei_sc, + long_index_t wei_sz, + long_index_t wei_sy, + long_index_t wei_sx, InElementOp in_op, WeiElementOp wei_op, OutElementOp out_op) @@ -73,25 +91,22 @@ naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_i if constexpr(NDimSpatial == 1) { const long_index_t num_wei = G * K * C * X; - const long_index_t in_stride_g = N * C * Wi; - const long_index_t in_stride_n = C * Wi; - const long_index_t in_stride_c = Wi; - const long_index_t out_stride_g = N * K * Wo; - const long_index_t out_stride_n = K * Wo; - const long_index_t out_stride_k = Wo; - const long_index_t wei_stride_g = K * C * X; - const long_index_t wei_stride_k = C * X; - const long_index_t wei_stride_c = X; + const long_index_t in_stride_g = in_sg; + const long_index_t in_stride_n = in_sn; + const long_index_t in_stride_c = in_sc; + const long_index_t out_stride_g = out_sg; + const long_index_t out_stride_n = out_sn; + const long_index_t out_stride_k = out_sk; for(long_index_t idx = tid; idx < num_wei; idx += num_threads) { - index_t remaining = idx; - const index_t x = remaining % X; + long_index_t remaining = idx; + const long_index_t x = remaining % X; remaining /= X; - const index_t c = remaining % C; + const long_index_t c = remaining % C; remaining /= C; - const index_t k = remaining % K; - const index_t g = remaining / K; + const long_index_t k = remaining % K; + const long_index_t g = remaining / K; float acc = 0.0f; // Base pointers for current group @@ -99,14 +114,14 @@ naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_i const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g; // Loop over batch and output 
positions - for(index_t n = 0; n < N; ++n) + for(long_index_t n = 0; n < N; ++n) { // Pointers at current batch and input channel const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c; const OutDataType* output_grad_at_n_k = output_grad_g + n * out_stride_n + k * out_stride_k; - for(index_t wo = 0; wo < Wo; ++wo) + for(long_index_t wo = 0; wo < Wo; ++wo) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) @@ -118,7 +133,7 @@ naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_i input_at_n_c, p_ins + 1, g * in_stride_g + n * in_stride_n + c * in_stride_c, - wi); + wi * in_sw); // Handle output gradient element-wise operation with extra B tensors detail::apply_multi_tensor_elementwise_op( @@ -127,7 +142,7 @@ naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_i output_grad_at_n_k, p_out_grads + 1, g * out_stride_g + n * out_stride_n + k * out_stride_k, - wo); + wo * out_sw); acc += type_convert(out_val) * type_convert(in_val); } @@ -137,36 +152,32 @@ naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_i detail::apply_d_tensor_elementwise_op( wei_val, wei_op, acc, p_ds, p_d_strides, g, k, c, x); - p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + x] = wei_val; + p_wei_grad[g * wei_sg + k * wei_sk + c * wei_sc + x * wei_sx] = wei_val; } } else if constexpr(NDimSpatial == 2) { const long_index_t num_wei = G * K * C * Y * X; - const long_index_t in_stride_g = N * C * Hi * Wi; - const long_index_t in_stride_n = C * Hi * Wi; - const long_index_t in_stride_c = Hi * Wi; - const long_index_t in_stride_h = Wi; - const long_index_t out_stride_g = N * K * Ho * Wo; - const long_index_t out_stride_n = K * Ho * Wo; - const long_index_t out_stride_k = Ho * Wo; - const long_index_t out_stride_h = Wo; - const long_index_t wei_stride_g = K * C * Y * X; - const long_index_t wei_stride_k = C * Y * X; - const long_index_t wei_stride_c = Y 
* X; - const long_index_t wei_stride_y = X; + const long_index_t in_stride_g = in_sg; + const long_index_t in_stride_n = in_sn; + const long_index_t in_stride_c = in_sc; + const long_index_t in_stride_h = in_sh; + const long_index_t out_stride_g = out_sg; + const long_index_t out_stride_n = out_sn; + const long_index_t out_stride_k = out_sk; + const long_index_t out_stride_h = out_sh; for(long_index_t idx = tid; idx < num_wei; idx += num_threads) { - index_t remaining = idx; - const index_t x = remaining % X; + long_index_t remaining = idx; + const long_index_t x = remaining % X; remaining /= X; - const index_t y = remaining % Y; + const long_index_t y = remaining % Y; remaining /= Y; - const index_t c = remaining % C; + const long_index_t c = remaining % C; remaining /= C; - const index_t k = remaining % K; - const index_t g = remaining / K; + const long_index_t k = remaining % K; + const long_index_t g = remaining / K; float acc = 0.0f; // Base pointers for current group @@ -204,7 +215,7 @@ naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_i p_ins + 1, g * in_stride_g + n * in_stride_n + c * in_stride_c + hi * in_stride_h, - wi); + wi * in_sw); // Handle output gradient element-wise operation with extra B // tensors @@ -215,7 +226,7 @@ naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_i p_out_grads + 1, g * out_stride_g + n * out_stride_n + k * out_stride_k + ho * out_stride_h, - wo); + wo * out_sw); acc += type_convert(out_val) * type_convert(in_val); } @@ -235,42 +246,36 @@ naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_i y * p_d_strides[0][3] + x * p_d_strides[0][4]); - p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + y * wei_stride_y + - x] = wei_val; + p_wei_grad[g * wei_sg + k * wei_sk + c * wei_sc + y * wei_sy + x * wei_sx] = wei_val; } } else if constexpr(NDimSpatial == 3) { const long_index_t num_wei = G * K * C * Z * Y * X; - const long_index_t 
in_stride_g = N * C * Di * Hi * Wi; - const long_index_t in_stride_n = C * Di * Hi * Wi; - const long_index_t in_stride_c = Di * Hi * Wi; - const long_index_t in_stride_d = Hi * Wi; - const long_index_t in_stride_h = Wi; - const long_index_t out_stride_g = N * K * Do * Ho * Wo; - const long_index_t out_stride_n = K * Do * Ho * Wo; - const long_index_t out_stride_k = Do * Ho * Wo; - const long_index_t out_stride_d = Ho * Wo; - const long_index_t out_stride_h = Wo; - const long_index_t wei_stride_g = K * C * Z * Y * X; - const long_index_t wei_stride_k = C * Z * Y * X; - const long_index_t wei_stride_c = Z * Y * X; - const long_index_t wei_stride_z = Y * X; - const long_index_t wei_stride_y = X; + const long_index_t in_stride_g = in_sg; + const long_index_t in_stride_n = in_sn; + const long_index_t in_stride_c = in_sc; + const long_index_t in_stride_d = in_sd; + const long_index_t in_stride_h = in_sh; + const long_index_t out_stride_g = out_sg; + const long_index_t out_stride_n = out_sn; + const long_index_t out_stride_k = out_sk; + const long_index_t out_stride_d = out_sd; + const long_index_t out_stride_h = out_sh; for(long_index_t idx = tid; idx < num_wei; idx += num_threads) { - index_t remaining = idx; - const index_t x = remaining % X; + long_index_t remaining = idx; + const long_index_t x = remaining % X; remaining /= X; - const index_t y = remaining % Y; + const long_index_t y = remaining % Y; remaining /= Y; - const index_t z = remaining % Z; + const long_index_t z = remaining % Z; remaining /= Z; - const index_t c = remaining % C; + const long_index_t c = remaining % C; remaining /= C; - const index_t k = remaining % K; - const index_t g = remaining / K; + const long_index_t k = remaining % K; + const long_index_t g = remaining / K; float acc = 0.0f; // Base pointers for current group @@ -318,7 +323,7 @@ naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_i p_ins + 1, g * in_stride_g + n * in_stride_n + c * in_stride_c + di * 
in_stride_d + hi * in_stride_h, - wi); + wi * in_sw); // Handle output gradient element-wise operation with extra // B tensors @@ -329,7 +334,7 @@ naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_i p_out_grads + 1, g * out_stride_g + n * out_stride_n + k * out_stride_k + do_idx * out_stride_d + ho * out_stride_h, - wo); + wo * out_sw); acc += type_convert(out_val) * type_convert(in_val); @@ -352,8 +357,8 @@ naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_i c, z * p_d_strides[0][3] + y * p_d_strides[0][4] + x * p_d_strides[0][5]); - p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + z * wei_stride_z + - y * wei_stride_y + x] = wei_val; + p_wei_grad[g * wei_sg + k * wei_sk + c * wei_sc + z * wei_sz + y * wei_sy + + x * wei_sx] = wei_val; } } } @@ -378,8 +383,8 @@ void naive_conv_bwd_weight_multi_abd( const std::array& p_outs, const std::array& p_ds, const ck::utils::conv::ConvParam& conv_param, - [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, - const std::array, NumDElementwise>& d_strides, + [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, + const std::array, NumDElementwise>& d_strides, InElementwiseOperation in_element_op = InElementwiseOperation{}, WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, OutElementwiseOperation out_element_op = OutElementwiseOperation{}, @@ -387,120 +392,37 @@ void naive_conv_bwd_weight_multi_abd( { const auto ndim = conv_param.num_dim_spatial_; - const index_t G = conv_param.G_; - const index_t N = conv_param.N_; - const index_t C = conv_param.C_; - const index_t K = conv_param.K_; + const long_index_t G = conv_param.G_; + const long_index_t N = conv_param.N_; + const long_index_t C = conv_param.C_; + const long_index_t K = conv_param.K_; - std::vector in_lengths = {G, N, C}; - std::vector wei_lengths = {G, K, C}; - std::vector out_lengths = {G, N, K}; + std::vector in_lengths = {G, N, C}; + std::vector 
wei_lengths = {G, K, C}; + std::vector out_lengths = {G, N, K}; for(index_t i = 0; i < ndim; ++i) { - in_lengths.push_back(static_cast(conv_param.input_spatial_lengths_[i])); - wei_lengths.push_back(static_cast(conv_param.filter_spatial_lengths_[i])); - out_lengths.push_back(static_cast(conv_param.output_spatial_lengths_[i])); + in_lengths.push_back(static_cast(conv_param.input_spatial_lengths_[i])); + wei_lengths.push_back(static_cast(conv_param.filter_spatial_lengths_[i])); + out_lengths.push_back(static_cast(conv_param.output_spatial_lengths_[i])); } - // Calculate total elements for buffer allocation - long_index_t in_total = 1, wei_total = 1, out_total = 1; - for(auto l : in_lengths) - in_total *= l; + // Calculate total elements for grid size + long_index_t wei_total = 1; for(auto l : wei_lengths) wei_total *= l; - for(auto l : out_lengths) - out_total *= l; - // Allocate packed buffers - std::vector in_packed_bufs; - in_packed_bufs.reserve(NumAElementwise + 1); - for(index_t i = 0; i <= NumAElementwise; ++i) - { - in_packed_bufs.emplace_back(in_total * sizeof(TIn)); - } - - SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei)); - - std::vector out_grad_packed_bufs; - out_grad_packed_bufs.reserve(NumBElementwise + 1); - for(index_t i = 0; i <= NumBElementwise; ++i) - { - out_grad_packed_bufs.emplace_back(out_total * sizeof(TOut)); - } - - std::array p_ins_packed; - for(index_t i = 0; i <= NumAElementwise; ++i) - { - p_ins_packed[i] = static_cast(in_packed_bufs[i].GetDeviceBuffer()); - } - - TWei* p_wei_grad_packed = static_cast(wei_grad_packed_buf.GetDeviceBuffer()); - - std::array p_out_grads_packed; - for(index_t i = 0; i <= NumBElementwise; ++i) - { - p_out_grads_packed[i] = static_cast(out_grad_packed_bufs[i].GetDeviceBuffer()); - } - - // Compute strides and allocate device arrays for pack/unpack - std::vector in_strides = compute_conv_tensor_strides(in_lengths, ndim); - std::vector wei_strides = compute_conv_tensor_strides(wei_lengths, ndim); - 
std::vector out_strides = compute_conv_tensor_strides(out_lengths, ndim); - - const size_t dim_count = in_lengths.size(); - SimpleDeviceMem in_lengths_buf(dim_count * sizeof(index_t)); - SimpleDeviceMem in_strides_buf(dim_count * sizeof(index_t)); - SimpleDeviceMem wei_lengths_buf(dim_count * sizeof(index_t)); - SimpleDeviceMem wei_strides_buf(dim_count * sizeof(index_t)); - SimpleDeviceMem out_lengths_buf(dim_count * sizeof(index_t)); - SimpleDeviceMem out_strides_buf(dim_count * sizeof(index_t)); - - index_t* d_in_lengths = static_cast(in_lengths_buf.GetDeviceBuffer()); - index_t* d_in_strides = static_cast(in_strides_buf.GetDeviceBuffer()); - index_t* d_wei_lengths = static_cast(wei_lengths_buf.GetDeviceBuffer()); - index_t* d_wei_strides = static_cast(wei_strides_buf.GetDeviceBuffer()); - index_t* d_out_lengths = static_cast(out_lengths_buf.GetDeviceBuffer()); - index_t* d_out_strides = static_cast(out_strides_buf.GetDeviceBuffer()); - - HIP_CHECK_ERROR(hipMemcpy( - d_in_lengths, in_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy( - d_in_strides, in_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy( - d_wei_lengths, wei_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy( - d_wei_strides, wei_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy( - d_out_lengths, out_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy( - d_out_strides, out_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - - // Pack input and output_grad tensors to contiguous layout (inputs to bwd weight) - constexpr int block_size = 256; - - for(index_t i = 0; i <= NumAElementwise; ++i) - { - strided_copy_kernel - <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, 
dim_count, in_total); - } - - for(index_t i = 0; i <= NumBElementwise; ++i) - { - strided_copy_kernel - <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_outs[i], - p_out_grads_packed[i], - d_out_lengths, - d_out_strides, - dim_count, - out_total); - } + // Compute strides from layout + std::vector in_strides = compute_conv_tensor_strides(in_lengths, ndim); + std::vector wei_strides = + compute_conv_tensor_strides(wei_lengths, ndim); + std::vector out_strides = + compute_conv_tensor_strides(out_lengths, ndim); // Prepare D tensor stride arrays on device std::vector d_stride_bufs; - std::array p_d_strides_dev = {}; + std::array p_d_strides_dev = {}; if constexpr(NumDElementwise > 0) { @@ -508,12 +430,12 @@ void naive_conv_bwd_weight_multi_abd( for(index_t i = 0; i < NumDElementwise; ++i) { - d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t)); - p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); + d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(long_index_t)); + p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i], d_strides[i].data(), - d_strides[i].size() * sizeof(index_t), + d_strides[i].size() * sizeof(long_index_t), hipMemcpyHostToDevice)); } } @@ -524,17 +446,16 @@ void naive_conv_bwd_weight_multi_abd( SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*)); SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*)); - TIn** d_ins_ptrs = static_cast(ins_ptrs_buf.GetDeviceBuffer()); - TOut** d_out_grads_ptrs = static_cast(out_grads_ptrs_buf.GetDeviceBuffer()); - TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); - index_t** d_d_strides_ptrs = static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); + TIn** d_ins_ptrs = static_cast(ins_ptrs_buf.GetDeviceBuffer()); + TOut** d_out_grads_ptrs = static_cast(out_grads_ptrs_buf.GetDeviceBuffer()); + TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); + 
long_index_t** d_d_strides_ptrs = + static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); - HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs, - p_ins_packed.data(), - (NumAElementwise + 1) * sizeof(TIn*), - hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy( + d_ins_ptrs, p_ins.data(), (NumAElementwise + 1) * sizeof(TIn*), hipMemcpyHostToDevice)); HIP_CHECK_ERROR(hipMemcpy(d_out_grads_ptrs, - p_out_grads_packed.data(), + p_outs.data(), (NumBElementwise + 1) * sizeof(TOut*), hipMemcpyHostToDevice)); @@ -555,21 +476,29 @@ void naive_conv_bwd_weight_multi_abd( } // Build conv parameter vectors for kernel invocation - std::vector conv_strides(ndim); - std::vector conv_dilations(ndim); - std::vector input_pads(ndim); + std::vector conv_strides(ndim); + std::vector conv_dilations(ndim); + std::vector input_pads(ndim); for(index_t i = 0; i < ndim; ++i) { - conv_strides[i] = static_cast(conv_param.conv_filter_strides_[i]); - conv_dilations[i] = static_cast(conv_param.conv_filter_dilations_[i]); - input_pads[i] = static_cast(conv_param.input_left_pads_[i]); + conv_strides[i] = static_cast(conv_param.conv_filter_strides_[i]); + conv_dilations[i] = static_cast(conv_param.conv_filter_dilations_[i]); + input_pads[i] = static_cast(conv_param.input_left_pads_[i]); } - // Run backward weight convolution kernel on packed data - const int wei_grid = (wei_total + block_size - 1) / block_size; + // Run backward weight convolution kernel directly on original tensors using layout strides + constexpr int block_size = 256; + const long_index_t wei_grid_unclamped = (wei_total + block_size - 1) / block_size; + // gridDim.x * blockDim.x must not overflow uint32_t; the kernel uses a grid-stride loop. 
+ constexpr long_index_t max_grid = + static_cast(std::numeric_limits::max()) / block_size; + const int wei_grid = static_cast(std::min(wei_grid_unclamped, max_grid)); if(ndim == 1) { + // in_strides: [sg, sn, sc, sw] (indices 0..3) + // out_strides: [sg, sn, sk, sw] (indices 0..3) + // wei_strides: [sg, sk, sc, sx] (indices 0..3) naive_conv_bwd_weight_packed_multi_abd<1, NumAElementwise, NumBElementwise, @@ -582,7 +511,7 @@ void naive_conv_bwd_weight_multi_abd( WeiElementwiseOperation, OutElementwiseOperation> <<>>(d_ins_ptrs, - p_wei_grad_packed, + p_wei_grad, d_out_grads_ptrs, d_ds_ptrs, d_d_strides_ptrs, @@ -608,12 +537,33 @@ void naive_conv_bwd_weight_multi_abd( 0, 0, input_pads[0], + in_strides[0], + in_strides[1], + in_strides[2], + 0, + 0, + in_strides[3], + out_strides[0], + out_strides[1], + out_strides[2], + 0, + 0, + out_strides[3], + wei_strides[0], + wei_strides[1], + wei_strides[2], + 0, + 0, + wei_strides[3], in_element_op, wei_element_op, out_element_op); } else if(ndim == 2) { + // in_strides: [sg, sn, sc, sh, sw] (indices 0..4) + // out_strides: [sg, sn, sk, sh, sw] (indices 0..4) + // wei_strides: [sg, sk, sc, sy, sx] (indices 0..4) naive_conv_bwd_weight_packed_multi_abd<2, NumAElementwise, NumBElementwise, @@ -626,7 +576,7 @@ void naive_conv_bwd_weight_multi_abd( WeiElementwiseOperation, OutElementwiseOperation> <<>>(d_ins_ptrs, - p_wei_grad_packed, + p_wei_grad, d_out_grads_ptrs, d_ds_ptrs, d_d_strides_ptrs, @@ -652,12 +602,33 @@ void naive_conv_bwd_weight_multi_abd( 0, input_pads[0], input_pads[1], + in_strides[0], + in_strides[1], + in_strides[2], + 0, + in_strides[3], + in_strides[4], + out_strides[0], + out_strides[1], + out_strides[2], + 0, + out_strides[3], + out_strides[4], + wei_strides[0], + wei_strides[1], + wei_strides[2], + 0, + wei_strides[3], + wei_strides[4], in_element_op, wei_element_op, out_element_op); } else // 3D { + // in_strides: [sg, sn, sc, sd, sh, sw] (indices 0..5) + // out_strides: [sg, sn, sk, sd, sh, sw] (indices 
0..5) + // wei_strides: [sg, sk, sc, sz, sy, sx] (indices 0..5) naive_conv_bwd_weight_packed_multi_abd<3, NumAElementwise, NumBElementwise, @@ -670,7 +641,7 @@ void naive_conv_bwd_weight_multi_abd( WeiElementwiseOperation, OutElementwiseOperation> <<>>(d_ins_ptrs, - p_wei_grad_packed, + p_wei_grad, d_out_grads_ptrs, d_ds_ptrs, d_d_strides_ptrs, @@ -696,15 +667,29 @@ void naive_conv_bwd_weight_multi_abd( input_pads[0], input_pads[1], input_pads[2], + in_strides[0], + in_strides[1], + in_strides[2], + in_strides[3], + in_strides[4], + in_strides[5], + out_strides[0], + out_strides[1], + out_strides[2], + out_strides[3], + out_strides[4], + out_strides[5], + wei_strides[0], + wei_strides[1], + wei_strides[2], + wei_strides[3], + wei_strides[4], + wei_strides[5], in_element_op, wei_element_op, out_element_op); } - // Unpack weight gradient - strided_copy_kernel<<>>( - p_wei_grad_packed, p_wei_grad, d_wei_lengths, d_wei_strides, dim_count, wei_total); - HIP_CHECK_ERROR(hipGetLastError()); // Memory automatically freed by SimpleDeviceMem destructors @@ -730,11 +715,11 @@ naive_conv_bwd_weight(const TIn* p_in, OutElementwiseOperation out_element_op = OutElementwiseOperation{}, hipStream_t stream = nullptr) { - std::array p_ins = {p_in}; - std::array p_outs = {p_out}; - std::array p_ds = {}; - std::array, 0> d_lengths = {}; - std::array, 0> d_strides = {}; + std::array p_ins = {p_in}; + std::array p_outs = {p_out}; + std::array p_ds = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; naive_conv_bwd_weight_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_ins, p_wei_grad, diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp index 7bf9b49998..78d88123cc 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp 
@@ -15,51 +15,65 @@ namespace ck { namespace ref { -// Optimized convolution kernel working with packed (contiguous) tensors with multi-ABD support -// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], -// output[G][N][K][spatial] template -__global__ void naive_conv_fwd_packed_multi_abd( - const InDataType* const* __restrict__ p_ins, // Array of input pointers (1 + NumAExtra) - const WeiDataType* const* __restrict__ p_weis, // Array of weight pointers (1 + NumBExtra) - const DDataType* const* __restrict__ p_ds, // Array of D tensor pointers - const index_t* const* __restrict__ p_d_strides, // Array of D tensor stride arrays - OutDataType* __restrict__ p_out, - index_t G, - index_t N, - index_t K, - index_t C, - index_t Di, - index_t Hi, - index_t Wi, - index_t Z, - index_t Y, - index_t X, - index_t Do, - index_t Ho, - index_t Wo, - index_t stride_z, - index_t stride_y, - index_t stride_x, - index_t dilation_z, - index_t dilation_y, - index_t dilation_x, - index_t pad_z, - index_t pad_y, - index_t pad_x, - InElementOp in_op, - WeiElementOp wei_op, - OutElementOp out_op) +__global__ void naive_conv_fwd_packed_multi_abd(const InDataType* const* __restrict__ p_ins, + const WeiDataType* const* __restrict__ p_weis, + const DDataType* const* __restrict__ p_ds, + const long_index_t* const* __restrict__ p_d_strides, + OutDataType* __restrict__ p_out, + long_index_t G, + long_index_t N, + long_index_t K, + long_index_t C, + long_index_t Di, + long_index_t Hi, + long_index_t Wi, + long_index_t Z, + long_index_t Y, + long_index_t X, + long_index_t Do, + long_index_t Ho, + long_index_t Wo, + long_index_t stride_z, + long_index_t stride_y, + long_index_t stride_x, + long_index_t dilation_z, + long_index_t dilation_y, + long_index_t dilation_x, + long_index_t pad_z, + long_index_t pad_y, + long_index_t pad_x, + long_index_t in_sg, + long_index_t in_sn, + long_index_t in_sc, + long_index_t in_sd, + long_index_t in_sh, + long_index_t in_sw, + 
long_index_t wei_sg, + long_index_t wei_sk, + long_index_t wei_sc, + long_index_t wei_sz, + long_index_t wei_sy, + long_index_t wei_sx, + long_index_t out_sg, + long_index_t out_sn, + long_index_t out_sk, + long_index_t out_sd, + long_index_t out_sh, + long_index_t out_sw, + InElementOp in_op, + WeiElementOp wei_op, + OutElementOp out_op) { const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x; const long_index_t num_threads = blockDim.x * gridDim.x; @@ -70,60 +84,47 @@ __global__ void naive_conv_fwd_packed_multi_abd( if constexpr(NDimSpatial == 1) { - const long_index_t num_out = G * N * K * Wo; - const long_index_t in_stride_g = N * C * Wi; - const long_index_t in_stride_n = C * Wi; - const long_index_t in_stride_c = Wi; - const long_index_t wei_stride_g = K * C * X; - const long_index_t wei_stride_k = C * X; - const long_index_t wei_stride_c = X; - const long_index_t out_stride_g = N * K * Wo; - const long_index_t out_stride_n = K * Wo; - const long_index_t out_stride_k = Wo; + const long_index_t num_out = G * N * K * Wo; for(long_index_t idx = tid; idx < num_out; idx += num_threads) { - index_t remaining = idx; - const index_t wo = remaining % Wo; + long_index_t remaining = idx; + const long_index_t wo = remaining % Wo; remaining /= Wo; - const index_t k = remaining % K; + const long_index_t k = remaining % K; remaining /= K; - const index_t n = remaining % N; - const index_t g = remaining / N; + const long_index_t n = remaining % N; + const long_index_t g = remaining / N; - float acc = 0.0f; - // Base pointers for current group, batch, and output channel - const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n; - const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k; + float acc = 0.0f; + const InDataType* input_g_n = p_ins[0] + g * in_sg + n * in_sn; + const WeiDataType* weight_g_k = p_weis[0] + g * wei_sg + k * wei_sk; for(index_t c = 0; c < C; ++c) { - // Pointers at current input channel - const 
InDataType* input_at_c = input_g_n + c * in_stride_c; - const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c; + const InDataType* input_at_c = input_g_n + c * in_sc; + const WeiDataType* weight_at_c = weight_g_k + c * wei_sc; for(index_t x = 0; x < X; ++x) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - // Handle input element-wise operation with extra A tensors - detail::apply_multi_tensor_elementwise_op( - in_val, - in_op, - input_at_c, - p_ins + 1, - g * in_stride_g + n * in_stride_n + c * in_stride_c, - wi); + detail::apply_multi_tensor_elementwise_op(in_val, + in_op, + input_at_c, + p_ins + 1, + g * in_sg + n * in_sn + + c * in_sc, + wi * in_sw); - // Handle weight element-wise operation with extra B tensors detail::apply_multi_tensor_elementwise_op( wei_val, wei_op, weight_at_c, p_weis + 1, - g * wei_stride_g + k * wei_stride_k + c * wei_stride_c, - x); + g * wei_sg + k * wei_sk + c * wei_sc, + x * wei_sx); acc += type_convert(in_val) * type_convert(wei_val); } @@ -133,81 +134,62 @@ __global__ void naive_conv_fwd_packed_multi_abd( detail::apply_d_tensor_elementwise_op( out_val, out_op, acc, p_ds, p_d_strides, g, n, k, wo); - p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + wo] = out_val; + p_out[g * out_sg + n * out_sn + k * out_sk + wo * out_sw] = out_val; } } else if constexpr(NDimSpatial == 2) { - const long_index_t num_out = G * N * K * Ho * Wo; - const long_index_t in_stride_g = N * C * Hi * Wi; - const long_index_t in_stride_n = C * Hi * Wi; - const long_index_t in_stride_c = Hi * Wi; - const long_index_t in_stride_h = Wi; - const long_index_t wei_stride_g = K * C * Y * X; - const long_index_t wei_stride_k = C * Y * X; - const long_index_t wei_stride_c = Y * X; - const long_index_t wei_stride_y = X; - const long_index_t out_stride_g = N * K * Ho * Wo; - const long_index_t out_stride_n = K * Ho * Wo; - const long_index_t out_stride_k = Ho * Wo; - const long_index_t out_stride_h = Wo; + const 
long_index_t num_out = G * N * K * Ho * Wo; for(long_index_t idx = tid; idx < num_out; idx += num_threads) { - index_t remaining = idx; - const index_t wo = remaining % Wo; + long_index_t remaining = idx; + const long_index_t wo = remaining % Wo; remaining /= Wo; - const index_t ho = remaining % Ho; + const long_index_t ho = remaining % Ho; remaining /= Ho; - const index_t k = remaining % K; + const long_index_t k = remaining % K; remaining /= K; - const index_t n = remaining % N; - const index_t g = remaining / N; + const long_index_t n = remaining % N; + const long_index_t g = remaining / N; - float acc = 0.0f; - // Base pointers for current group, batch, and output channel - const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n; - const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k; + float acc = 0.0f; + const InDataType* input_g_n = p_ins[0] + g * in_sg + n * in_sn; + const WeiDataType* weight_g_k = p_weis[0] + g * wei_sg + k * wei_sk; for(index_t c = 0; c < C; ++c) { - // Pointers at current input channel - const InDataType* input_at_c = input_g_n + c * in_stride_c; - const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c; + const InDataType* input_at_c = input_g_n + c * in_sc; + const WeiDataType* weight_at_c = weight_g_k + c * wei_sc; for(index_t y = 0; y < Y; ++y) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - // Pointers at current spatial height and filter Y position - const InDataType* input_at_h = input_at_c + hi * in_stride_h; - const WeiDataType* weight_at_y = weight_at_c + y * wei_stride_y; + const InDataType* input_at_h = input_at_c + hi * in_sh; + const WeiDataType* weight_at_y = weight_at_c + y * wei_sy; for(index_t x = 0; x < X; ++x) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - // Handle input element-wise operation with extra A tensors detail::apply_multi_tensor_elementwise_op( in_val, in_op, input_at_h, 
p_ins + 1, - g * in_stride_g + n * in_stride_n + c * in_stride_c + - hi * in_stride_h, - wi); + g * in_sg + n * in_sn + c * in_sc + hi * in_sh, + wi * in_sw); - // Handle weight element-wise operation with extra B tensors detail::apply_multi_tensor_elementwise_op( wei_val, wei_op, weight_at_y, p_weis + 1, - g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + - y * wei_stride_y, - x); + g * wei_sg + k * wei_sk + c * wei_sc + y * wei_sy, + x * wei_sx); acc += type_convert(in_val) * type_convert(wei_val); } @@ -227,96 +209,74 @@ __global__ void naive_conv_fwd_packed_multi_abd( ho * p_d_strides[0][3] + wo * p_d_strides[0][4]); - p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + ho * out_stride_h + wo] = - out_val; + p_out[g * out_sg + n * out_sn + k * out_sk + ho * out_sh + wo * out_sw] = out_val; } } else if constexpr(NDimSpatial == 3) { - const long_index_t num_out = G * N * K * Do * Ho * Wo; - const long_index_t in_stride_g = N * C * Di * Hi * Wi; - const long_index_t in_stride_n = C * Di * Hi * Wi; - const long_index_t in_stride_c = Di * Hi * Wi; - const long_index_t in_stride_d = Hi * Wi; - const long_index_t in_stride_h = Wi; - const long_index_t wei_stride_g = K * C * Z * Y * X; - const long_index_t wei_stride_k = C * Z * Y * X; - const long_index_t wei_stride_c = Z * Y * X; - const long_index_t wei_stride_z = Y * X; - const long_index_t wei_stride_y = X; - const long_index_t out_stride_g = N * K * Do * Ho * Wo; - const long_index_t out_stride_n = K * Do * Ho * Wo; - const long_index_t out_stride_k = Do * Ho * Wo; - const long_index_t out_stride_d = Ho * Wo; - const long_index_t out_stride_h = Wo; + const long_index_t num_out = G * N * K * Do * Ho * Wo; for(long_index_t idx = tid; idx < num_out; idx += num_threads) { - index_t remaining = idx; - const index_t wo = remaining % Wo; + long_index_t remaining = idx; + const long_index_t wo = remaining % Wo; remaining /= Wo; - const index_t ho = remaining % Ho; + const long_index_t ho = remaining % 
Ho; remaining /= Ho; - const index_t do_idx = remaining % Do; + const long_index_t do_idx = remaining % Do; remaining /= Do; - const index_t k = remaining % K; + const long_index_t k = remaining % K; remaining /= K; - const index_t n = remaining % N; - const index_t g = remaining / N; + const long_index_t n = remaining % N; + const long_index_t g = remaining / N; - float acc = 0.0f; - // Base pointers for current group, batch, and output channel - const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n; - const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k; + float acc = 0.0f; + const InDataType* input_g_n = p_ins[0] + g * in_sg + n * in_sn; + const WeiDataType* weight_g_k = p_weis[0] + g * wei_sg + k * wei_sk; for(index_t c = 0; c < C; ++c) { - // Pointers at current input channel - const InDataType* input_at_c = input_g_n + c * in_stride_c; - const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c; + const InDataType* input_at_c = input_g_n + c * in_sc; + const WeiDataType* weight_at_c = weight_g_k + c * wei_sc; for(index_t z = 0; z < Z; ++z) { long_index_t di = do_idx * stride_z + z * dilation_z - pad_z; if(di >= 0 && di < Di) { - // Pointers at current spatial depth - const InDataType* input_at_d = input_at_c + di * in_stride_d; - const WeiDataType* weight_at_z = weight_at_c + z * wei_stride_z; + const InDataType* input_at_d = input_at_c + di * in_sd; + const WeiDataType* weight_at_z = weight_at_c + z * wei_sz; for(index_t y = 0; y < Y; ++y) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - // Pointers at current spatial depth and height - const InDataType* input_at_d_h = input_at_d + hi * in_stride_h; - const WeiDataType* weight_at_z_y = weight_at_z + y * wei_stride_y; + const InDataType* input_at_d_h = input_at_d + hi * in_sh; + const WeiDataType* weight_at_z_y = weight_at_z + y * wei_sy; for(index_t x = 0; x < X; ++x) { long_index_t wi = wo * stride_x + x * dilation_x 
- pad_x; if(wi >= 0 && wi < Wi) { - // Handle input element-wise operation with extra A tensors detail::apply_multi_tensor_elementwise_op( in_val, in_op, input_at_d_h, p_ins + 1, - g * in_stride_g + n * in_stride_n + c * in_stride_c + - di * in_stride_d + hi * in_stride_h, - wi); + g * in_sg + n * in_sn + c * in_sc + di * in_sd + + hi * in_sh, + wi * in_sw); - // Handle weight element-wise operation with extra B tensors detail::apply_multi_tensor_elementwise_op( wei_val, wei_op, weight_at_z_y, p_weis + 1, - g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + - z * wei_stride_z + y * wei_stride_y, - x); + g * wei_sg + k * wei_sk + c * wei_sc + z * wei_sz + + y * wei_sy, + x * wei_sx); acc += type_convert(in_val) * type_convert(wei_val); @@ -339,13 +299,12 @@ __global__ void naive_conv_fwd_packed_multi_abd( k, do_idx * p_d_strides[0][3] + ho * p_d_strides[0][4] + wo * p_d_strides[0][5]); - p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + do_idx * out_stride_d + - ho * out_stride_h + wo] = out_val; + p_out[g * out_sg + n * out_sn + k * out_sk + do_idx * out_sd + ho * out_sh + + wo * out_sw] = out_val; } } } -// GPU reference convolution with multi-ABD support - takes ConvParam directly template // D tensor type, defaults to TOut for backward compatibility + typename TD = TOut> void naive_conv_fwd_multi_abd( const std::array& p_ins, const std::array& p_weis, const std::array& p_ds, TOut* p_out, const ck::utils::conv::ConvParam& conv_param, - [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, - const std::array, NumDElementwise>& d_strides, + [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, + const std::array, NumDElementwise>& d_strides, InElementwiseOperation in_element_op = InElementwiseOperation{}, WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, OutElementwiseOperation out_element_op = OutElementwiseOperation{}, @@ -374,121 +333,35 @@ void naive_conv_fwd_multi_abd( { const auto ndim = 
conv_param.num_dim_spatial_; - const index_t G = conv_param.G_; - const index_t N = conv_param.N_; - const index_t C = conv_param.C_; - const index_t K = conv_param.K_; + const long_index_t G = conv_param.G_; + const long_index_t N = conv_param.N_; + const long_index_t C = conv_param.C_; + const long_index_t K = conv_param.K_; - std::vector in_lengths = {G, N, C}; - std::vector wei_lengths = {G, K, C}; - std::vector out_lengths = {G, N, K}; + std::vector in_lengths = {G, N, C}; + std::vector wei_lengths = {G, K, C}; + std::vector out_lengths = {G, N, K}; for(index_t i = 0; i < ndim; ++i) { - in_lengths.push_back(static_cast(conv_param.input_spatial_lengths_[i])); - wei_lengths.push_back(static_cast(conv_param.filter_spatial_lengths_[i])); - out_lengths.push_back(static_cast(conv_param.output_spatial_lengths_[i])); + in_lengths.push_back(static_cast(conv_param.input_spatial_lengths_[i])); + wei_lengths.push_back(static_cast(conv_param.filter_spatial_lengths_[i])); + out_lengths.push_back(static_cast(conv_param.output_spatial_lengths_[i])); } - // Calculate total elements for buffer allocation - long_index_t in_total = 1, wei_total = 1, out_total = 1; - for(auto l : in_lengths) - in_total *= l; - for(auto l : wei_lengths) - wei_total *= l; + long_index_t out_total = 1; for(auto l : out_lengths) out_total *= l; - // Allocate packed buffers for all A and B tensors - // Use separate allocations to avoid copy assignment issues with RAII wrapper - std::vector in_packed_bufs; - in_packed_bufs.reserve(NumAElementwise + 1); - for(index_t i = 0; i <= NumAElementwise; ++i) - { - in_packed_bufs.emplace_back(in_total * sizeof(TIn)); - } - - std::vector wei_packed_bufs; - wei_packed_bufs.reserve(NumBElementwise + 1); - for(index_t i = 0; i <= NumBElementwise; ++i) - { - wei_packed_bufs.emplace_back(wei_total * sizeof(TWei)); - } - - SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut)); - - // Get packed buffer pointers - std::array p_ins_packed; - for(index_t i = 0; i <= 
NumAElementwise; ++i) - { - p_ins_packed[i] = static_cast(in_packed_bufs[i].GetDeviceBuffer()); - } - - std::array p_weis_packed; - for(index_t i = 0; i <= NumBElementwise; ++i) - { - p_weis_packed[i] = static_cast(wei_packed_bufs[i].GetDeviceBuffer()); - } - - TOut* p_out_packed = static_cast(out_packed_buf.GetDeviceBuffer()); - - // Compute strides and allocate device arrays for pack/unpack - std::vector in_strides = compute_conv_tensor_strides(in_lengths, ndim); - std::vector wei_strides = compute_conv_tensor_strides(wei_lengths, ndim); - std::vector out_strides = compute_conv_tensor_strides(out_lengths, ndim); - - const size_t dim_count = in_lengths.size(); - SimpleDeviceMem in_lengths_buf(dim_count * sizeof(index_t)); - SimpleDeviceMem in_strides_buf(dim_count * sizeof(index_t)); - SimpleDeviceMem wei_lengths_buf(dim_count * sizeof(index_t)); - SimpleDeviceMem wei_strides_buf(dim_count * sizeof(index_t)); - SimpleDeviceMem out_lengths_buf(dim_count * sizeof(index_t)); - SimpleDeviceMem out_strides_buf(dim_count * sizeof(index_t)); - - index_t* d_in_lengths = static_cast(in_lengths_buf.GetDeviceBuffer()); - index_t* d_in_strides = static_cast(in_strides_buf.GetDeviceBuffer()); - index_t* d_wei_lengths = static_cast(wei_lengths_buf.GetDeviceBuffer()); - index_t* d_wei_strides = static_cast(wei_strides_buf.GetDeviceBuffer()); - index_t* d_out_lengths = static_cast(out_lengths_buf.GetDeviceBuffer()); - index_t* d_out_strides = static_cast(out_strides_buf.GetDeviceBuffer()); - - HIP_CHECK_ERROR(hipMemcpy( - d_in_lengths, in_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy( - d_in_strides, in_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy( - d_wei_lengths, wei_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy( - d_wei_strides, wei_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - 
HIP_CHECK_ERROR(hipMemcpy( - d_out_lengths, out_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy( - d_out_strides, out_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice)); - - // Pack input and weight tensors to contiguous layout - constexpr int block_size = 256; - - // Pack all A tensors - for(index_t i = 0; i <= NumAElementwise; ++i) - { - strided_copy_kernel - <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total); - } - - // Pack all B tensors - for(index_t i = 0; i <= NumBElementwise; ++i) - { - strided_copy_kernel - <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total); - } + std::vector in_strides = compute_conv_tensor_strides(in_lengths, ndim); + std::vector wei_strides = + compute_conv_tensor_strides(wei_lengths, ndim); + std::vector out_strides = + compute_conv_tensor_strides(out_lengths, ndim); // Prepare D tensor stride arrays on device - // NOTE: D tensors are NOT packed - they are used directly with their original strides - // to support broadcasting (e.g., BiasGK layout with zero strides) std::vector d_stride_bufs; - std::array p_d_strides_dev = {}; + std::array p_d_strides_dev = {}; if constexpr(NumDElementwise > 0) { @@ -496,40 +369,35 @@ void naive_conv_fwd_multi_abd( for(index_t i = 0; i < NumDElementwise; ++i) { - // Allocate and copy strides to device - d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t)); - p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); + d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(long_index_t)); + p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i], d_strides[i].data(), - d_strides[i].size() * sizeof(index_t), + d_strides[i].size() * sizeof(long_index_t), 
hipMemcpyHostToDevice)); } } - // Create device arrays of pointers + // Create device pointer arrays (use original pointers directly, no packing) SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*)); SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*)); SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*)); - SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*)); + SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(long_index_t*)); - TIn** d_ins_ptrs = static_cast(ins_ptrs_buf.GetDeviceBuffer()); - TWei** d_weis_ptrs = static_cast(weis_ptrs_buf.GetDeviceBuffer()); - TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); - index_t** d_d_strides_ptrs = static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); + TIn** d_ins_ptrs = static_cast(ins_ptrs_buf.GetDeviceBuffer()); + TWei** d_weis_ptrs = static_cast(weis_ptrs_buf.GetDeviceBuffer()); + TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); + long_index_t** d_d_strides_ptrs = + static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); - HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs, - p_ins_packed.data(), - (NumAElementwise + 1) * sizeof(TIn*), - hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs, - p_weis_packed.data(), - (NumBElementwise + 1) * sizeof(TWei*), - hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy( + d_ins_ptrs, p_ins.data(), (NumAElementwise + 1) * sizeof(TIn*), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy( + d_weis_ptrs, p_weis.data(), (NumBElementwise + 1) * sizeof(TWei*), hipMemcpyHostToDevice)); if constexpr(NumDElementwise > 0) { - // D tensors use original pointers (not packed) to support broadcasting std::array p_ds_dev; for(index_t i = 0; i < NumDElementwise; ++i) { @@ -540,23 +408,51 @@ void naive_conv_fwd_multi_abd( d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice)); HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs, p_d_strides_dev.data(), - NumDElementwise * sizeof(index_t*), + 
NumDElementwise * sizeof(long_index_t*), hipMemcpyHostToDevice)); } - // Build conv parameter vectors for kernel invocation - std::vector conv_strides(ndim); - std::vector conv_dilations(ndim); - std::vector input_pads(ndim); + std::vector conv_strides(ndim); + std::vector conv_dilations(ndim); + std::vector input_pads(ndim); for(index_t i = 0; i < ndim; ++i) { - conv_strides[i] = static_cast(conv_param.conv_filter_strides_[i]); - conv_dilations[i] = static_cast(conv_param.conv_filter_dilations_[i]); - input_pads[i] = static_cast(conv_param.input_left_pads_[i]); + conv_strides[i] = static_cast(conv_param.conv_filter_strides_[i]); + conv_dilations[i] = static_cast(conv_param.conv_filter_dilations_[i]); + input_pads[i] = static_cast(conv_param.input_left_pads_[i]); } - // Run convolution kernel on packed data - const int out_grid = (out_total + block_size - 1) / block_size; + // Extract strides indexed as [G,N,C,spatial...] and [G,K,C,spatial...] / [G,N,K,spatial...] + // in_strides: [0]=sg [1]=sn [2]=sc [3]=sd [4]=sh [5]=sw + // wei_strides: [0]=sg [1]=sk [2]=sc [3]=sz [4]=sy [5]=sx + // out_strides: [0]=sg [1]=sn [2]=sk [3]=sd [4]=sh [5]=sw + const long_index_t in_sg = in_strides[0]; + const long_index_t in_sn = in_strides[1]; + const long_index_t in_sc = in_strides[2]; + const long_index_t in_sd = (ndim >= 3) ? in_strides[3] : 0; + const long_index_t in_sh = (ndim >= 2) ? in_strides[ndim == 3 ? 4 : 3] : 0; + const long_index_t in_sw = in_strides[ndim == 3 ? 5 : (ndim == 2 ? 4 : 3)]; + + const long_index_t wei_sg = wei_strides[0]; + const long_index_t wei_sk = wei_strides[1]; + const long_index_t wei_sc = wei_strides[2]; + const long_index_t wei_sz = (ndim >= 3) ? wei_strides[3] : 0; + const long_index_t wei_sy = (ndim >= 2) ? wei_strides[ndim == 3 ? 4 : 3] : 0; + const long_index_t wei_sx = wei_strides[ndim == 3 ? 5 : (ndim == 2 ? 
4 : 3)]; + + const long_index_t out_sg = out_strides[0]; + const long_index_t out_sn = out_strides[1]; + const long_index_t out_sk = out_strides[2]; + const long_index_t out_sd = (ndim >= 3) ? out_strides[3] : 0; + const long_index_t out_sh = (ndim >= 2) ? out_strides[ndim == 3 ? 4 : 3] : 0; + const long_index_t out_sw = out_strides[ndim == 3 ? 5 : (ndim == 2 ? 4 : 3)]; + + constexpr int block_size = 256; + const long_index_t out_grid_unclamped = (out_total + block_size - 1) / block_size; + // gridDim.x * blockDim.x must not overflow uint32_t; the kernel uses a grid-stride loop. + constexpr long_index_t max_grid = + static_cast(std::numeric_limits::max()) / block_size; + const int out_grid = static_cast(std::min(out_grid_unclamped, max_grid)); if(ndim == 1) { @@ -575,7 +471,7 @@ void naive_conv_fwd_multi_abd( d_weis_ptrs, d_ds_ptrs, d_d_strides_ptrs, - p_out_packed, + p_out, G, N, K, @@ -598,6 +494,24 @@ void naive_conv_fwd_multi_abd( 0, 0, input_pads[0], + in_sg, + in_sn, + in_sc, + in_sd, + in_sh, + in_sw, + wei_sg, + wei_sk, + wei_sc, + wei_sz, + wei_sy, + wei_sx, + out_sg, + out_sn, + out_sk, + out_sd, + out_sh, + out_sw, in_element_op, wei_element_op, out_element_op); @@ -619,7 +533,7 @@ void naive_conv_fwd_multi_abd( d_weis_ptrs, d_ds_ptrs, d_d_strides_ptrs, - p_out_packed, + p_out, G, N, K, @@ -642,6 +556,24 @@ void naive_conv_fwd_multi_abd( 0, input_pads[0], input_pads[1], + in_sg, + in_sn, + in_sc, + in_sd, + in_sh, + in_sw, + wei_sg, + wei_sk, + wei_sc, + wei_sz, + wei_sy, + wei_sx, + out_sg, + out_sn, + out_sk, + out_sd, + out_sh, + out_sw, in_element_op, wei_element_op, out_element_op); @@ -663,7 +595,7 @@ void naive_conv_fwd_multi_abd( d_weis_ptrs, d_ds_ptrs, d_d_strides_ptrs, - p_out_packed, + p_out, G, N, K, @@ -686,21 +618,32 @@ void naive_conv_fwd_multi_abd( input_pads[0], input_pads[1], input_pads[2], + in_sg, + in_sn, + in_sc, + in_sd, + in_sh, + in_sw, + wei_sg, + wei_sk, + wei_sc, + wei_sz, + wei_sy, + wei_sx, + out_sg, + out_sn, + out_sk, + 
out_sd, + out_sh, + out_sw, in_element_op, wei_element_op, out_element_op); } - // Unpack - strided_copy_kernel<<>>( - p_out_packed, p_out, d_out_lengths, d_out_strides, dim_count, out_total); - HIP_CHECK_ERROR(hipGetLastError()); - - // Memory automatically freed by SimpleDeviceMem destructors } -// Original naive_conv_fwd - now a zero-overhead wrapper template p_ins = {p_in}; - std::array p_weis = {p_wei}; - std::array p_ds = {}; - std::array, 0> d_lengths = {}; - std::array, 0> d_strides = {}; + std::array p_ins = {p_in}; + std::array p_weis = {p_wei}; + std::array p_ds = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; naive_conv_fwd_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_ins, p_weis, diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp index 50b65357a2..40e933e4a9 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp @@ -110,12 +110,12 @@ inline int map_dim_char_to_index(char dim_char, index_t ndim_spatial, bool is_we // Template function to compute layout-aware strides based on layout name // The layout name directly encodes memory ordering from left to right template -inline std::vector compute_conv_tensor_strides(const std::vector& lengths, - index_t ndim_spatial) +inline std::vector +compute_conv_tensor_strides(const std::vector& lengths, index_t ndim_spatial) { constexpr const char* layout_name = Layout::name; const int num_dims = static_cast(lengths.size()); - std::vector strides(num_dims, 0); + std::vector strides(num_dims, 0); // Determine if this is a weight tensor (has 'K' but not 'N') bool has_k = false; @@ -146,7 +146,7 @@ inline std::vector compute_conv_tensor_strides(const std::vector(dim_order.size()) - 1; i >= 0; --i) { char dim_char = dim_order[i]; @@ -168,8 +168,8 @@ 
inline std::vector compute_conv_tensor_strides(const std::vector __global__ void strided_copy_kernel(const DataType* __restrict__ src, DataType* __restrict__ dst, - const index_t* tensor_lengths, - const index_t* strided_strides, + const long_index_t* tensor_lengths, + const long_index_t* strided_strides, int num_dims, long_index_t total_elements) { @@ -184,7 +184,7 @@ __global__ void strided_copy_kernel(const DataType* __restrict__ src, for(int dim = num_dims - 1; dim >= 0; --dim) { - index_t coord = remaining % tensor_lengths[dim]; + long_index_t coord = remaining % tensor_lengths[dim]; remaining /= tensor_lengths[dim]; strided_idx += coord * strided_strides[dim]; } @@ -253,15 +253,16 @@ __device__ __forceinline__ void apply_d_tensor_impl(OutDataType& result_out, // Specialized helper for D tensors with stride calculations and float conversion template -__device__ __forceinline__ void apply_d_tensor_elementwise_op(OutDataType& result_out, - Op&& element_op, - float computed_value, - const DDataType* const* p_ds, - const index_t* const* p_d_strides, - index_t g, - index_t n, - index_t c_or_k, - long_index_t spatial_linear_index) +__device__ __forceinline__ void +apply_d_tensor_elementwise_op(OutDataType& result_out, + Op&& element_op, + float computed_value, + const DDataType* const* p_ds, + const long_index_t* const* p_d_strides, + long_index_t g, + long_index_t n, + long_index_t c_or_k, + long_index_t spatial_linear_index) { if constexpr(NumDTensors == 0) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp index e2b0cb74ba..bf123a5e00 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp +++ 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp @@ -79,6 +79,100 @@ using device_grouped_conv_bwd_data_xdl_v3_bf16_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_bwd_data_xdl_v3_f16_large_tensor_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // AK1=8, BK1=8 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 64, 1, 4>, 8, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 2, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, false, true>, + // ScalarPerVector = 4 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, false, true>, + // ScalarPerVector = 2 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, 1, 1, S<1, 64, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 64, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 
2>, 2, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, 1, 1, S<1, 64, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, false, true>, + // ScalarPerVector = 1 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<2, 128, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<2, 128, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, false, true> + + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_xdl_v3_bf16_large_tensor_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // AK1=8, BK1=8 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 2, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, false, true>, + // ScalarPerVector = 4 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<4, 32, 2>, 
S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, false, true>, + // ScalarPerVector = 2 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, 1, 1, S<1, 64, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 64, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, 1, 1, S<1, 64, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, false, true>, + // ScalarPerVector = 1 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<2, 128, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<2, 128, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, false, true> + + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_xdl_v3_f32_large_tensor_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| 
BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // ScalarPerVector = 4 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F32, F32, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F32, F32, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 
1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F32, F32, false, true>, + // ScalarPerVector = 2 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, 1, 1, S<1, 64, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F32, F32, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 64, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F32, F32, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, 1, 1, S<1, 64, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F32, F32, false, true>, + // ScalarPerVector = 1 + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F32, F32, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F32, F32, false, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F32, F32, false, true> + + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp index 3633015a78..a0ad407256 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp @@ -384,6 +384,48 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_part2_in // clang-format on >; +template +using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_large_tensor_instances = + std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| Large | + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|Tensors| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| 
MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 64, 8, 32, 32, 2, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, 1, F16, F16, 1, 1, true>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 32, 32, 2, 2, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, 1, F16, F16, 1, 1, true>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 256, 32, 8, 32, 32, 4, 4, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, 1, F16, F16, 1, 1, true>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 64, 8, 32, 32, 2, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, 1, F16, F16, 1, 1, true> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_large_tensor_instances = + 
std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| Large | + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|Tensors| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 64, 8, 32, 32, 2, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, 1, BF16, BF16, 1, 1, true>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, 
ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 32, 32, 2, 2, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, 1, BF16, BF16, 1, 1, true>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 256, 32, 8, 32, 32, 4, 4, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, 1, BF16, BF16, 1, 1, true>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 64, 8, 32, 32, 2, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, 1, BF16, BF16, 1, 1, true> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp index c3834c7d17..947946e004 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp @@ -234,6 +234,83 @@ using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances = ::std::declval>(), 
::std::declval>())); +// large-tensor variants (uses ck::long_index_t for index arithmetic; requires packed tensors) +template +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_large_tensor_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // vector size 8 + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 64, 8, 32, 32, 2, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 16, 1>, S<2, 0, 1>, 
S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, BF16, BF16, false, 1, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 32, 32, 2, 4, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, BF16, BF16, false, 1, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 64, 8, 32, 32, 4, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, BF16, BF16, false, 1, true>, + // vector size 4 + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 64, 8, 32, 32, 2, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, BF16, BF16, false, 1, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 32, 32, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, BF16, BF16, false, 1, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 64, 
8, 32, 32, 4, 2, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, BF16, BF16, false, 1, true>, + // vector size 2 + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 64, 8, 32, 32, 2, 2, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 8, 0, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, BF16, BF16, false, 1, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 32, 32, 2, 4, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 8, 0, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, BF16, BF16, false, 1, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 64, 8, 32, 32, 4, 2, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 8, 0, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, BF16, BF16, false, 1, true> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_large_tensor_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| 
BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // vector size 8 + // vector size 8 + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 64, 8, 32, 32, 2, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F16, F16, false, 1, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 32, 32, 2, 4, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F16, F16, false, 1, true>, + 
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 64, 8, 32, 32, 4, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F16, F16, false, 1, true>, + // vector size 4 + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 64, 8, 32, 32, 2, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F16, F16, false, 1, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 32, 32, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F16, F16, false, 1, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 64, 8, 32, 32, 4, 2, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F16, F16, false, 1, true>, + // vector size 2 + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 64, 8, 32, 32, 2, 2, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 8, 0, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F16, F16, false, 1, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 32, 32, 2, 4, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 8, 0, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F16, F16, false, 1, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 64, 8, 32, 32, 4, 2, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 8, 0, S<4, 64, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F16, F16, false, 1, true> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_large_tensor_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // vector size 4 + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 32, 32, 2, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F32, F32, false, 1, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 32, 32, 2, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F32, F32, false, 1, true>, + // vector size 2 + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 32, 32, 2, 2, S<4, 32, 2>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 32, 2>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F32, F32, false, 1, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 32, 32, 2, 2, S<4, 
32, 2>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 32, 2>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F32, F32, false, 1, true>, + // vector size 1 + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 32, 32, 2, 2, S<4, 32, 2>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 2, 0, S<4, 32, 2>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F32, F32, false, 1, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 32, 32, 2, 2, S<4, 32, 2>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 2, 0, S<4, 32, 2>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F32, F32, false, 1, true> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp index a55038637b..ce9e67476c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp @@ -286,6 +286,113 @@ using device_grouped_conv_fwd_xdl_int8_comp_instances_part2 = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, 
BlockGemmPipelineVersion::v1> // clang-format on >; +// large-tensor variants (uses ck::long_index_t for index arithmetic; requires packed tensors) +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_bf16_comp_instances_large_tensors = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // vector size 8 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, 
BlockGemmPipelineVersion::v3, BF16, BF16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, BF16, BF16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, BF16, BF16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, BF16, BF16, false, 1, true>, + // vector size 4 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, BF16, BF16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, BF16, BF16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, BF16, BF16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, BF16, BF16, false, 1, true>, + // vector size 2 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, BF16, BF16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, BF16, BF16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, BF16, BF16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, BF16, BF16, false, 1, true>, + // vector size 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, BF16, BF16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, BF16, BF16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, BF16, BF16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, BF16, BF16, false, 1, 
true> + // clang-format on + >; + +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_f16_comp_instances_large_tensors = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // vector size 8 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F16, F16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 
2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F16, F16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F16, F16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F16, F16, false, 1, true>, + // vector size 4 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F16, F16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F16, F16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F16, F16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F16, F16, false, 1, true>, + // vector size 2 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, 
F16, F16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F16, F16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F16, F16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F16, F16, false, 1, true>, + // vector size 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F16, F16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F16, F16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F16, F16, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F16, F16, false, 1, true> + // clang-format on + >; + +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_f32_comp_instances_large_tensors = 
std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // vector size 4 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F32, F32, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, 
BlockGemmPipelineVersion::v5, F32, F32, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F32, F32, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F32, F32, false, 1, true>, + // vector size 2 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F32, F32, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F32, F32, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F32, F32, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 0, 1, 1, S<1, 32, 1, 8>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F32, F32, false, 1, true>, + // vector size 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F32, F32, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 0, S<8, 32, 1>, S<1, 
0, 2>, S<1, 0, 2>, 2, 1, 4, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F32, F32, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F32, F32, false, 1, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F32, F32, false, 1, true> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 09301474f0..f1f4eb3a39 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -110,6 +110,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances( op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_large_tensors_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances( op_ptrs); @@ -136,6 +138,8 @@ struct DeviceOperationInstanceFactory< #ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { + add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f32_large_tensors_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( op_ptrs); 
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances( @@ -152,6 +156,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances( op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_large_tensors_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( @@ -273,6 +279,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { + add_device_grouped_conv3d_bwd_data_xdl_v3_ndhwgk_gkzyxc_ndhwgc_f16_large_tensors_instances( + op_ptrs); add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( op_ptrs); add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances( @@ -298,6 +306,8 @@ struct DeviceOperationInstanceFactory< #ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { + add_device_grouped_conv3d_bwd_data_xdl_v3_ndhwgk_gkzyxc_ndhwgc_f32_large_tensors_instances( + op_ptrs); add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( op_ptrs); add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances( @@ -323,6 +333,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { + add_device_grouped_conv3d_bwd_data_xdl_v3_ndhwgk_gkzyxc_ndhwgc_bf16_large_tensors_instances( + op_ptrs); add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( op_ptrs); add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc index 8dae166dd1..9bd6829795 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc +++ 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc @@ -262,6 +262,52 @@ void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances( PassThrough>>>& instances); #endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_large_tensors_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_large_tensors_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f32_large_tensors_instances( + std::vector>>& instances); +#endif + #ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f16_instances( std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_xdl_v3_ndhwgk_gkzyxc_ndhwgc_f16_large_tensors_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_data_xdl_v3_ndhwgk_gkzyxc_ndhwgc_bf16_large_tensors_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_data_xdl_v3_ndhwgk_gkzyxc_ndhwgc_f32_large_tensors_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f16_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_large_tensors_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_large_tensors_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_large_tensors_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_large_tensors_instances( + std::vector>>& instances); + +void 
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_large_tensors_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_large_tensors_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_large_tensors_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_large_tensors_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_large_tensors_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_large_tensors_instances( + std::vector>>& instances); + #endif #ifdef CK_ENABLE_TF32 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index b90cd44df0..f3902dba33 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -221,6 +221,8 @@ struct DeviceOperationInstanceFactory>>& instances); #endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_large_tensors_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_large_tensors_instances( + 
std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_large_tensors_instances( + std::vector>>& instances); +#endif + // grouped conv2d forward, NGCHW/GKCYX/NGKHW #ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instances( @@ -321,6 +369,54 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instan TF32>>>& instances); #endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_large_tensors_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_large_tensors_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_large_tensors_instances( + std::vector>>& instances); +#endif + // grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW #ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances( diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/CMakeLists.txt index 67ae53887b..074a7a436f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/CMakeLists.txt @@ -28,6 +28,9 @@ set(GROUPED_CONV2D_BWD_DATA_NHWGC # xdl_v3 xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_large_tensors_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_large_tensors_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f32_large_tensors_instance.cpp ) 
add_instance_library(device_grouped_conv2d_bwd_data_nhwgc_instance ${GROUPED_CONV2D_BWD_DATA_NHWGC}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_large_tensors_instance.cpp new file mode 100644 index 0000000000..ccc494ba9a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_large_tensors_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_bf16_large_tensor_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_large_tensors_instance.cpp new file mode 100644 index 0000000000..f8d7d48569 --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_large_tensors_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_f16_large_tensor_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f32_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f32_large_tensors_instance.cpp new file mode 100644 index 0000000000..7479942156 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/nhwgc/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f32_large_tensors_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f32_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_f32_large_tensor_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/CMakeLists.txt index 21a8cf94ef..80087a93e0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/CMakeLists.txt @@ -52,6 +52,11 @@ set(GROUPED_CONV2D_BWD_WEIGHT_NHWGC xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev2_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instance.cpp + xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_large_tensors_instance.cpp + xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_large_tensors_instance.cpp + xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_large_tensors_instance.cpp + 
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_large_tensors_instance.cpp + xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_large_tensors_instance.cpp ) if(DL_KERNELS) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_large_tensors_instance.cpp new file mode 100644 index 0000000000..c7214769dd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_large_tensors_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_large_tensor_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_large_tensors_instance.cpp new file mode 100644 index 0000000000..4117c706af --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_large_tensors_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_large_tensor_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_large_tensors_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_large_tensors_instance.cpp new file mode 100644 index 0000000000..2f189a3412 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_large_tensors_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_large_tensor_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_large_tensors_instance.cpp new file mode 100644 index 0000000000..ec1b35592d --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_large_tensors_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_large_tensor_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_large_tensors_instance.cpp new file mode 100644 index 0000000000..7410a285f3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/nhwgc/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_large_tensors_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_large_tensor_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/CMakeLists.txt index 319796ecc6..bce0196e5a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/CMakeLists.txt @@ -39,6 +39,9 @@ set(GROUPED_CONV2D_FWD_NHWGC xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_large_tensors_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_large_tensors_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_large_tensors_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp 
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_large_tensors_instance.cpp new file mode 100644 index 0000000000..ad854c8726 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_large_tensors_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_large_tensors<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_large_tensors_instance.cpp new file mode 100644 index 0000000000..e387aa856f --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_large_tensors_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_large_tensors<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_large_tensors_instance.cpp new file mode 100644 index 0000000000..c4ca4ff215 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/nhwgc/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_large_tensors_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_comp_instances_large_tensors<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/CMakeLists.txt index 34abb9c9c9..f581af18ea 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/CMakeLists.txt @@ -25,6 +25,10 @@ set(GROUPED_CONV3D_BWD_DATA_NDHWGC xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16_16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_optimized_loads_instance.cpp + # xdl_v3 + xdl/device_grouped_conv3d_bwd_data_xdl_v3_ndhwgc_gkzyxc_ndhwgk_bf16_large_tensors_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_v3_ndhwgc_gkzyxc_ndhwgk_f16_large_tensors_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_v3_ndhwgc_gkzyxc_ndhwgk_f32_large_tensors_instance.cpp ) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/xdl/device_grouped_conv3d_bwd_data_xdl_v3_ndhwgc_gkzyxc_ndhwgk_bf16_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/xdl/device_grouped_conv3d_bwd_data_xdl_v3_ndhwgc_gkzyxc_ndhwgk_bf16_large_tensors_instance.cpp new file mode 100644 index 0000000000..45d12e7993 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/xdl/device_grouped_conv3d_bwd_data_xdl_v3_ndhwgc_gkzyxc_ndhwgk_bf16_large_tensors_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_xdl_v3_ndhwgk_gkzyxc_ndhwgc_bf16_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_bf16_large_tensor_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/xdl/device_grouped_conv3d_bwd_data_xdl_v3_ndhwgc_gkzyxc_ndhwgk_f16_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/xdl/device_grouped_conv3d_bwd_data_xdl_v3_ndhwgc_gkzyxc_ndhwgk_f16_large_tensors_instance.cpp new file mode 100644 index 0000000000..af01a127cc --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/xdl/device_grouped_conv3d_bwd_data_xdl_v3_ndhwgc_gkzyxc_ndhwgk_f16_large_tensors_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_xdl_v3_ndhwgk_gkzyxc_ndhwgc_f16_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_f16_large_tensor_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/xdl/device_grouped_conv3d_bwd_data_xdl_v3_ndhwgc_gkzyxc_ndhwgk_f32_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/xdl/device_grouped_conv3d_bwd_data_xdl_v3_ndhwgc_gkzyxc_ndhwgk_f32_large_tensors_instance.cpp new file mode 100644 index 0000000000..7528af153e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/ndhwgc/xdl/device_grouped_conv3d_bwd_data_xdl_v3_ndhwgc_gkzyxc_ndhwgk_f32_large_tensors_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_xdl_v3_ndhwgk_gkzyxc_ndhwgc_f32_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_f32_large_tensor_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/CMakeLists.txt index 4e4a2f0f33..8ff358cd8d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/CMakeLists.txt @@ -46,6 +46,11 @@ set(GROUPED_CONV3D_BWD_WEIGHT_NDHWGC xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_large_tensors_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_large_tensors_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_large_tensors_instance.cpp + 
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_large_tensors_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_large_tensors_instance.cpp ) if(DL_KERNELS) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_large_tensors_instance.cpp new file mode 100644 index 0000000000..a64ded5680 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_large_tensors_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_large_tensor_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_large_tensors_instance.cpp new file mode 100644 index 0000000000..eeb631b688 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_large_tensors_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_large_tensor_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_large_tensors_instance.cpp new file mode 100644 index 0000000000..b79af10483 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_large_tensors_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_large_tensor_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_large_tensors_instance.cpp new file mode 100644 index 0000000000..a8ab04a0d1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_large_tensors_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_large_tensor_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_large_tensors_instance.cpp new file mode 100644 index 0000000000..bb0f5a0e01 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/ndhwgc/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_large_tensors_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_large_tensor_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/CMakeLists.txt index 9306b225eb..72e06b9e99 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/CMakeLists.txt @@ -27,6 +27,9 @@ set(GROUPED_CONV3D_FWD_NDHWGC # xdl xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_large_tensors_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_large_tensors_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_large_tensors_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp 
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_large_tensors_instance.cpp new file mode 100644 index 0000000000..3584cc319a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_large_tensors_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_large_tensors<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_large_tensors_instance.cpp new file mode 100644 index 0000000000..2cdb9fa7cf --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_large_tensors_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_large_tensors<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_large_tensors_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_large_tensors_instance.cpp new file mode 100644 index 0000000000..ff48869fc5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/ndhwgc/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_large_tensors_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_large_tensors_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_comp_instances_large_tensors<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 836786bd10..bf37cc8d46 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -538,6 +538,9 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, LogRangeAsType(std::cout << "input: ", input.mData, ",") << std::endl; } + + std::cout << "Relative error threshold: " << rtol + << " Absolute error threshold: " << atol << std::endl; } } else if(do_verification == 1) diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp index 8b4df83b44..6c632d8ebb 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp @@ -217,15 +217,16 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, else if(do_verification == 2) { // GPU reference - std::vector d_lengths_vec(NDimSpatial + 3); - std::vector d_strides_vec(NDimSpatial + 3); + std::vector d_lengths_vec(NDimSpatial + 3); + 
std::vector d_strides_vec(NDimSpatial + 3); d_lengths_vec[0] = conv_param.G_; d_lengths_vec[1] = conv_param.N_; d_lengths_vec[2] = conv_param.K_; for(ck::index_t i = 0; i < NDimSpatial; ++i) { - d_lengths_vec[3 + i] = static_cast(conv_param.output_spatial_lengths_[i]); + d_lengths_vec[3 + i] = + static_cast(conv_param.output_spatial_lengths_[i]); } if constexpr(BiasGK) @@ -247,8 +248,8 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, std::array d_ptrs = { reinterpret_cast(bias_device_buf.GetDeviceBuffer())}; - std::array, 1> d_lengths = {d_lengths_vec}; - std::array, 1> d_strides = {d_strides_vec}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; std::array in_ptrs = { reinterpret_cast(in_device_buf.GetDeviceBuffer())}; diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp index 6d9425728f..d07c196bc0 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp @@ -194,15 +194,16 @@ bool profile_grouped_conv_fwd_bilinear_impl( else if(do_verification == 2) { // GPU reference - std::vector d_lengths_vec(NDimSpatial + 3); - std::vector d_strides_vec(NDimSpatial + 3); + std::vector d_lengths_vec(NDimSpatial + 3); + std::vector d_strides_vec(NDimSpatial + 3); d_lengths_vec[0] = conv_param.G_; d_lengths_vec[1] = conv_param.N_; d_lengths_vec[2] = conv_param.K_; for(ck::index_t i = 0; i < NDimSpatial; ++i) { - d_lengths_vec[3 + i] = static_cast(conv_param.output_spatial_lengths_[i]); + d_lengths_vec[3 + i] = + static_cast(conv_param.output_spatial_lengths_[i]); } // D tensor has same layout as output @@ -210,8 +211,8 @@ bool profile_grouped_conv_fwd_bilinear_impl( std::array d_ptrs = { reinterpret_cast(d_device_buf.GetDeviceBuffer())}; - std::array, 1> d_lengths = {d_lengths_vec}; - std::array, 1> d_strides = 
{d_strides_vec}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; std::array in_ptrs = { reinterpret_cast(in_device_buf.GetDeviceBuffer())}; diff --git a/profiler/src/profile_grouped_conv_bwd_data.cpp b/profiler/src/profile_grouped_conv_bwd_data.cpp index cc7ce88996..30e2e414d8 100644 --- a/profiler/src/profile_grouped_conv_bwd_data.cpp +++ b/profiler/src/profile_grouped_conv_bwd_data.cpp @@ -328,13 +328,13 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) return 1; } - const auto data_type = static_cast(std::stoi(argv[2])); - const auto layout = static_cast(std::stoi(argv[3])); - const bool do_verification = std::stoi(argv[4]); - const int init_method = std::stoi(argv[5]); - const bool do_log = std::stoi(argv[6]); - const bool time_kernel = std::stoi(argv[7]); - const int num_dim_spatial = std::stoi(argv[8]); + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const int do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const int num_dim_spatial = std::stoi(argv[8]); // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial, 1 for split-K if(positional_argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1) diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index b15b639c05..d25f4ee560 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -376,13 +376,13 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) return 1; } - const auto data_type = static_cast(std::stoi(argv[2])); - const auto layout = static_cast(std::stoi(argv[3])); - const bool do_verification = std::stoi(argv[4]); - const int init_method = std::stoi(argv[5]); - const bool do_log = std::stoi(argv[6]); - const bool 
time_kernel = std::stoi(argv[7]); - const int num_dim_spatial = std::stoi(argv[8]); + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const int do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const int num_dim_spatial = std::stoi(argv[8]); // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial, 1 for split-K if(positional_argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1) diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index 301736b4ad..b7223d4843 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -421,14 +421,14 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) return 1; } - const auto data_type = static_cast(std::stoi(argv[2])); - const auto layout = static_cast(std::stoi(argv[3])); - const auto index_type = static_cast(std::stoi(argv[4])); - const bool do_verification = std::stoi(argv[5]); - const int init_method = std::stoi(argv[6]); - const bool do_log = std::stoi(argv[7]); - const bool time_kernel = std::stoi(argv[8]); - const int num_dim_spatial = std::stoi(argv[9]); + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const auto index_type = static_cast(std::stoi(argv[4])); + const int do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial if(positional_argc != 9 + 1 + 4 + 6 * num_dim_spatial) diff --git a/python/ck4inductor/grouped_conv_fwd/op.py b/python/ck4inductor/grouped_conv_fwd/op.py index 576c36f66d..f47c9a12ed 100644 --- 
a/python/ck4inductor/grouped_conv_fwd/op.py +++ b/python/ck4inductor/grouped_conv_fwd/op.py @@ -68,6 +68,7 @@ class CKGroupedConvFwdOp: direct_load: Optional[bool] = None num_groups_to_merge: Optional[int] = None + large_tensor: Optional[bool] = None def name(self): # cpp alias for template instance diff --git a/test/gpu_reference/gpu_reference_utils.hpp b/test/gpu_reference/gpu_reference_utils.hpp index 88306d51a4..12742a2bb7 100644 --- a/test/gpu_reference/gpu_reference_utils.hpp +++ b/test/gpu_reference/gpu_reference_utils.hpp @@ -404,22 +404,22 @@ bool test_conv_fwd_with_d_tensor_impl(const ck::utils::conv::ConvParam& params, using WeiElementOp = tensor_operation::element_wise::PassThrough; // Create D tensor lengths and strides for GPU reference - std::vector d_lengths_vec(NDimSpatial + 3); + std::vector d_lengths_vec(NDimSpatial + 3); d_lengths_vec[0] = params.G_; d_lengths_vec[1] = params.N_; d_lengths_vec[2] = params.K_; for(index_t i = 0; i < NDimSpatial; ++i) { - d_lengths_vec[3 + i] = static_cast(params.output_spatial_lengths_[i]); + d_lengths_vec[3 + i] = static_cast(params.output_spatial_lengths_[i]); } - std::vector d_strides_vec = + std::vector d_strides_vec = ref::compute_conv_tensor_strides(d_lengths_vec, params.num_dim_spatial_); std::array d_ptrs = { reinterpret_cast(d_dev.GetDeviceBuffer())}; - std::array, 1> d_lengths = {d_lengths_vec}; - std::array, 1> d_strides = {d_strides_vec}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; // Call GPU reference with D tensor std::array in_ptrs = { @@ -536,9 +536,9 @@ bool test_conv_fwd_with_multi_ab_impl(const ck::utils::conv::ConvParam& params, std::array wei_ptrs = { reinterpret_cast(weight_dev.GetDeviceBuffer()), reinterpret_cast(b_extra_dev.GetDeviceBuffer())}; - std::array d_ptrs = {}; - std::array, 0> d_lengths = {}; - std::array, 0> d_strides = {}; + std::array d_ptrs = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; 
ref::naive_conv_fwd_multi_abd<1, 1, 0, InLayout, WeiLayout, OutLayout>( in_ptrs, diff --git a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp index 24259e1524..5f8145377c 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp @@ -108,8 +108,8 @@ class TestGroupedConvndBwdData : public ::testing::Test conv_param); // Prepare D tensor with correct strides for GPU kernel - std::vector d_lengths; - std::vector d_strides; + std::vector d_lengths; + std::vector d_strides; auto copy_dims = [](const auto& desc, auto& lengths, auto& strides) { const auto& l = desc.GetLengths(); const auto& s = desc.GetStrides(); @@ -118,8 +118,8 @@ class TestGroupedConvndBwdData : public ::testing::Test }; copy_dims(in_g_n_c_wis_desc, d_lengths, d_strides); - std::array, NumDs> d_lengths_array = {d_lengths}; - std::array, NumDs> d_strides_array = {d_strides}; + std::array, NumDs> d_lengths_array = {d_lengths}; + std::array, NumDs> d_strides_array = {d_strides}; DeviceMem d_device_buf(sizeof(InDataType) * d.mDesc.GetElementSpaceSize()); d_device_buf.ToDevice(d.mData.data()); diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_large_cases.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_large_cases.cpp index 64d1bbbee7..8eef152327 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_large_cases.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_large_cases.cpp @@ -10,8 +10,6 @@ #include #include "profiler/profile_grouped_conv_bwd_data_impl.hpp" -static ck::index_t param_mask = 0xffff; -static ck::index_t instance_index = -1; template class TestGroupedConvndBwdData : public ::testing::Test @@ -32,27 +30,21 @@ class TestGroupedConvndBwdData : public ::testing::Test bool pass = true; for(auto split_k : split_ks) { - for(size_t i = 0; i < 
conv_params.size(); i++) + for(auto& param : conv_params) { - if((param_mask & (1 << i)) == 0) - { - continue; - } - auto& param = conv_params[i]; - pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl( - true, // do_verification + pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl( + 2, // do_verification 1, // init_method: integer value false, // do_log false, // time_kernel param, - split_k, - instance_index); + split_k); } } EXPECT_TRUE(pass); @@ -61,29 +53,11 @@ class TestGroupedConvndBwdData : public ::testing::Test using namespace ck::tensor_layout::convolution; -using KernelTypes2d = ::testing::Types, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, +using KernelTypes2d = ::testing::Types, std::tuple, std::tuple>; -using KernelTypes3d = ::testing::Types, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, +using KernelTypes3d = ::testing::Types, std::tuple, std::tuple>; @@ -103,16 +77,25 @@ TYPED_TEST_SUITE(TestGroupedConvndBwdData3d, KernelTypes3d); TYPED_TEST(TestGroupedConvndBwdData2d, Test2D) { this->conv_params.clear(); - // SplitN case + // Case larger than 2GB this->conv_params.push_back( {2, 1, 128, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}}); + // With supported NumGroupsToMerge > 1 + this->conv_params.push_back( + {2, 32, 64, 1, 1, {2, 2}, {672, 672}, {672, 672}, {1, 1}, {0, 0}, {0, 0}}); + // When image is larger than 2GB + this->conv_params.push_back( + {2, 2, 2, 128, 128, {3, 3}, {4096, 2048}, {300, 300}, {3, 3}, {1, 1}, {1, 1}}); + // Split N and G > 1 + this->conv_params.push_back( + {2, 4, 112, 8, 8, {3, 3}, {469, 724}, {2, 2}, {2, 2}, {1, 1}, {1, 1}}); this->template Run<2>(); } TYPED_TEST(TestGroupedConvndBwdData3d, Test3D) { this->conv_params.clear(); - // SplitN case + // Case larger than 2GB this->conv_params.push_back({3, 1, 128, @@ -124,22 
+107,29 @@ TYPED_TEST(TestGroupedConvndBwdData3d, Test3D) {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + // With supported NumGroupsToMerge > 1 + this->conv_params.push_back({3, + 32, + 64, + 1, + 1, + {2, 2, 2}, + {360, 2, 672}, + {360, 2, 672}, + {1, 1, 1}, + {0, 0, 0}, + {0, 0, 0}}); + // When image is larger than 2GB + this->conv_params.push_back({3, + 1, + 2, + 128, + 128, + {3, 1, 3}, + {900, 2, 2048}, + {300, 1, 300}, + {3, 2, 3}, + {1, 1, 1}, + {1, 1, 1}}); this->template Run<3>(); } - -int main(int argc, char** argv) -{ - testing::InitGoogleTest(&argc, argv); - if(argc == 1) {} - else if(argc == 3) - { - param_mask = strtol(argv[1], nullptr, 0); - instance_index = atoi(argv[2]); - } - else - { - std::cout << "Usage of " << argv[0] << std::endl; - std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; - } - return RUN_ALL_TESTS(); -} diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index ad2a221ec8..43a59b90a3 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -8,6 +8,10 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) target_link_device_conv_libraries_if_exist(test_grouped_convnd_bwd_weight PRIVATE utility device_conv_operations) + add_executable(test_grouped_convnd_bwd_weight_large_cases test_grouped_convnd_bwd_weight_large_cases.cpp) + target_compile_options(test_grouped_convnd_bwd_weight_large_cases PRIVATE -Wno-global-constructors -Wno-undef) + target_link_libraries(test_grouped_convnd_bwd_weight_large_cases PRIVATE gtest_main getopt::getopt utility device_conv_operations) + add_gtest_executable(test_grouped_convnd_bwd_weight_bilinear test_grouped_convnd_bwd_weight_bilinear.cpp) target_link_device_conv_libraries_if_exist(test_grouped_convnd_bwd_weight_bilinear PRIVATE utility device_conv_operations) 
add_gtest_executable(test_grouped_convnd_bwd_weight_scale test_grouped_convnd_bwd_weight_scale.cpp) diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp index 801899f94d..6c72f4f4ed 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp @@ -101,8 +101,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test conv_param); // Prepare D tensor with correct strides for GPU kernel - std::vector d_lengths; - std::vector d_strides; + std::vector d_lengths; + std::vector d_strides; auto copy_dims = [](const auto& desc, auto& lengths, auto& strides) { const auto& l = desc.GetLengths(); const auto& s = desc.GetStrides(); @@ -111,8 +111,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test }; copy_dims(wei_g_k_c_xs_desc, d_lengths, d_strides); - std::array, NumDs> d_lengths_array = {d_lengths}; - std::array, NumDs> d_strides_array = {d_strides}; + std::array, NumDs> d_lengths_array = {d_lengths}; + std::array, NumDs> d_strides_array = {d_strides}; DeviceMem d_device_buf(sizeof(WeiDataType) * d.mDesc.GetElementSpaceSize()); d_device_buf.ToDevice(d.mData.data()); diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_dataset_xdl.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_dataset_xdl.cpp index d90824db62..d8b005644e 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_dataset_xdl.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_dataset_xdl.cpp @@ -20,8 +20,6 @@ #pragma clang diagnostic ignored "-Wlifetime-safety-invalidation" #endif using namespace ck::tensor_layout::convolution; -static ck::index_t param_mask [[maybe_unused]] = 0xffff; -static ck::index_t instance_index = -1; // Load CSV data for 2D tests static std::vector Get2DTestCases() { @@ -96,7 
+94,7 @@ bool RunConvBwdWeightTest(const ck::utils::conv::ConvParam& param, ck::index_t s false, // time_kernel param, // ConvParam std::to_string(split_k), // Split-K value as string - instance_index); // instance_index + -1); // instance_index } // 2D Tests - NHWGK layout - Float - SplitK=1 @@ -267,22 +265,6 @@ INSTANTIATE_TEST_SUITE_P(Dataset, TestGroupedConvndBwdWeight3dNDHWGKBFloat16SplitK2, ::testing::ValuesIn(Get3DTestCases())); -int main(int argc, char** argv) -{ - testing::InitGoogleTest(&argc, argv); - if(argc == 1) {} - else if(argc == 3) - { - param_mask = strtol(argv[1], nullptr, 0); - instance_index = atoi(argv[2]); - } - else - { - std::cout << "Usage of " << argv[0] << std::endl; - std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; - } - return RUN_ALL_TESTS(); -} #if __clang_major__ >= 23 #pragma clang diagnostic pop #endif diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_large_cases.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_large_cases.cpp new file mode 100644 index 0000000000..9c2a216f71 --- /dev/null +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_large_cases.cpp @@ -0,0 +1,164 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/host_utility/device_prop.hpp" + +#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp" + +using namespace ck::tensor_layout::convolution; + +template +class TestGroupedConvndBwdWeight : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + using InLayout = std::tuple_element_t<1, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using OutLayout = std::tuple_element_t<3, Tuple>; + + std::vector conv_params; + std::vector split_ks{1, 2}; + + bool skip_case(const ck::index_t split_k) + { + // 1d NWGC is only supported by DL kernel + // DL kernel is only supported for split_k=1 + if constexpr(std::is_same_v && std::is_same_v) + { + if(split_k != 1) + { + return true; + } + } + + return false; + } + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + + for(auto split_k : split_ks) + { + for(size_t i = 0; i < conv_params.size(); i++) + { + auto& param = conv_params[i]; + if(!skip_case(split_k)) + { + const bool success = + ck::profiler::profile_grouped_conv_bwd_weight_impl( + 2, // do_verification + 2, // init_method: integer value + false, // do_log + false, // time_kernel + param, + std::to_string(split_k), + -1); + pass = pass && success; + if(!success) + std::cout << "Case " << param << " failed!" 
<< std::endl; + } + } + } + EXPECT_TRUE(pass); + } +}; + +template +class TestGroupedConvndBwdWeight2d : public TestGroupedConvndBwdWeight +{ +}; + +template +class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight +{ +}; + +using KernelTypes2d = ::testing::Types, + std::tuple, + std::tuple>; + +using KernelTypes3d = ::testing::Types, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndBwdWeight3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D) +{ + this->conv_params.clear(); + // Case larger than 2GB + this->conv_params.push_back( + {2, 1, 128, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}}); + // With supported NumGroupsToMerge > 1 + this->conv_params.push_back( + {2, 32, 64, 1, 1, {2, 2}, {672, 672}, {672, 672}, {1, 1}, {0, 0}, {0, 0}}); + // When image is larger than 2GB + this->conv_params.push_back( + {2, 2, 2, 128, 128, {3, 3}, {4096, 2048}, {300, 300}, {3, 3}, {1, 1}, {1, 1}}); + // Split N and G > 1 + this->conv_params.push_back( + {2, 4, 112, 8, 8, {3, 3}, {469, 724}, {2, 2}, {2, 2}, {1, 1}, {1, 1}}); + this->template Run<2>(); +} + +TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D) +{ + this->conv_params.clear(); + // Case larger than 2GB + this->conv_params.push_back({3, + 1, + 128, + 4, + 192, + {2, 2, 2}, + {2, 224, 224}, + {1, 224, 224}, + {1, 1, 1}, + {0, 0, 0}, + {0, 0, 0}}); + // With supported NumGroupsToMerge > 1 + this->conv_params.push_back({3, + 32, + 64, + 1, + 1, + {2, 2, 2}, + {360, 2, 672}, + {360, 2, 672}, + {1, 1, 1}, + {0, 0, 0}, + {0, 0, 0}}); + // When image is larger than 2GB + this->conv_params.push_back({3, + 1, + 2, + 128, + 128, + {3, 1, 3}, + {900, 2, 2048}, + {300, 1, 300}, + {3, 2, 3}, + {1, 1, 1}, + {1, 1, 1}}); + this->template Run<3>(); +} diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases.cpp index 
c270ae6491..6452f345fe 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases.cpp @@ -8,8 +8,7 @@ #include #include "profiler/profile_grouped_conv_fwd_impl.hpp" -static ck::index_t param_mask = 0xffff; -static ck::index_t instance_index = -1; + template class TestGroupedConvndFwd : public ::testing::Test { @@ -27,30 +26,23 @@ class TestGroupedConvndFwd : public ::testing::Test { EXPECT_FALSE(conv_params.empty()); bool pass = true; - for(size_t i = 0; i < conv_params.size(); i++) + for(auto& param : conv_params) { - if((param_mask & (1 << i)) == 0) - { - continue; - } - auto& param = conv_params[i]; - pass = pass && ck::profiler::profile_grouped_conv_fwd_impl( - true, // do_verification + pass = pass && ck::profiler::profile_grouped_conv_fwd_impl( + 2, // do_verification 1, // init_method: integer value false, // do_log false, // time_kernel - param, - ck::tensor_operation::element_wise::PassThrough{}, - instance_index); + param); } EXPECT_TRUE(pass); } @@ -60,8 +52,7 @@ using namespace ck::tensor_layout::convolution; using KernelTypes2d = ::testing::Types, std::tuple, - std::tuple, - std::tuple>; + std::tuple>; using KernelTypes3d = ::testing::Types, std::tuple, @@ -137,20 +128,3 @@ TYPED_TEST(TestGroupedConvndFwd3d, Test3D) {1, 1, 1}}); this->template Run<3>(); } - -int main(int argc, char** argv) -{ - testing::InitGoogleTest(&argc, argv); - if(argc == 1) {} - else if(argc == 3) - { - param_mask = strtol(argv[1], nullptr, 0); - instance_index = atoi(argv[2]); - } - else - { - std::cout << "Usage of " << argv[0] << std::endl; - std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; - } - return RUN_ALL_TESTS(); -} diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp index be8f04a7d0..aac43acf2c 100644 --- 
a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp @@ -197,9 +197,9 @@ bool profile_grouped_conv_fwd_scaleadd_ab_impl(int do_verification, std::array wei_ptrs = { reinterpret_cast(wei_device_buf.GetDeviceBuffer()), reinterpret_cast(wei_bias_device_buf.GetDeviceBuffer())}; - std::array d_ptrs = {}; - std::array, 0> d_lengths = {}; - std::array, 0> d_strides = {}; + std::array d_ptrs = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; ck::ref::naive_conv_fwd_multi_abd<1, 1, 0, InLayout, WeiLayout, OutLayout>( in_ptrs,