From d413c30ff4dc9e7cc49a8780708a1369f1b6ce80 Mon Sep 17 00:00:00 2001
From: Bartłomiej Kocot
Date: Fri, 14 Jun 2024 16:53:03 +0200
Subject: [PATCH] Support large tensors in grouped conv fwd (#1332)

* Support large tensors in grouped conv fwd

* Multi ABD fixes

* Fix calculate element space size

[ROCm/composable_kernel commit: dc1e9c5df9e022b130337cc31fd8a32f6ce1efa7]
---
 .../impl/device_column_to_image_impl.hpp      |   5 +-
 ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp |  11 +-
 .../device_grouped_conv_bwd_weight_dl.hpp     |   9 +-
 ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp |   9 +-
 ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp |  20 +--
 ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp |   9 +-
 ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp |  18 +-
 ...ice_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp |   9 +-
 ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 165 ++++++++++++------
 ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 125 +++++++------
 ...fwd_multiple_d_multiple_r_xdl_cshuffle.hpp |   9 +-
 ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp |   9 +-
 .../device/impl/device_grouped_conv_utils.hpp |  60 +++----
 .../impl/device_image_to_column_impl.hpp      |   5 +-
 ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp |  20 +--
 .../transform_conv_fwd_to_gemm.hpp            |  95 ++++++++--
 .../test_grouped_convnd_fwd.cpp               |  16 ++
 17 files changed, 369 insertions(+), 225 deletions(-)

diff --git a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
index 4c6546239b..a7a366ffbc 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -247,7 +247,8 @@ struct DeviceColumnToImageImpl
             independent_filter_strides,
             conv_filter_dilations,
             input_left_pads_with_offset,
-            input_right_pads);
+            input_right_pads,
+            N);
 
         const auto in_gemmm_gemmk_desc =
             matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
index c0fa9ad882..409e8c7b8b 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
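Note on the pattern repeated in the kernel hunks that follow: the per-group pointer offsets used to be passed through __builtin_amdgcn_readfirstlane, a builtin that (as exposed by the compiler) takes and returns a 32-bit value, so any offset beyond 2^31 elements came back truncated. The patch assigns the long_index_t offsets directly instead. A standalone illustration of the difference, with readfirstlane_like standing in for the builtin (the stand-in is not part of the patch):

    #include <cstdint>
    #include <cstdio>

    using index_t      = int32_t;
    using long_index_t = int64_t;

    // Stand-in for __builtin_amdgcn_readfirstlane: a 32-bit in, 32-bit out function.
    index_t readfirstlane_like(index_t v) { return v; }

    int main()
    {
        const long_index_t offset = 3'000'000'000; // a valid 64-bit element offset

        // Old pattern: the offset is squeezed through the 32-bit builtin and comes back mangled.
        const long_index_t old_result = readfirstlane_like(static_cast<index_t>(offset));
        // New pattern: the 64-bit value is used directly.
        const long_index_t new_result = offset;

        std::printf("old: %lld, new: %lld\n",
                    static_cast<long long>(old_result), static_cast<long long>(new_result));
        return 0;
    }

The builtin acted as a uniformity hint for the compiler; the trade accepted here appears to be giving up that hint in exchange for keeping the full 64-bit offset, which is what tensors above 2 GB require.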
#pragma once @@ -93,12 +93,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index bd264a3c81..83db2485a1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -54,12 +54,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 3c33c7dbc1..380a06e0d8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -66,12 +66,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = 
compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index c704cf059e..963f3f254c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -59,12 +59,9 @@ __global__ void const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -116,12 +113,9 @@ __global__ void const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy @@ -1268,7 +1262,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle arg.Conv_G_; std::array in_out_batch_strides = { - arg.compute_ptr_offset_of_batch_.BatchStrideC_}; + static_cast(arg.compute_ptr_offset_of_batch_.BatchStrideC_)}; const auto kernel = kernel_batched_elementwise, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 96854e9a8d..3babd1896f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -61,12 +61,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / 
num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 7cfbd8a8f3..3bb53920b2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -97,12 +97,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); @@ -266,7 +263,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + a_g_n_c_wis_lengths[I1]); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -312,8 +310,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK const std::array& e_g_n_k_wos_strides) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index 6a4d97d7d2..5c9d63e2b0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -263,7 +263,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd& c_g_n_k_wos_strides) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(c_g_n_k_wos_lengths, - c_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + c_g_n_k_wos_lengths, c_g_n_k_wos_strides, c_g_n_k_wos_lengths[I1]); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 28ad91efdd..88fe38adde 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -69,7 +69,8 @@ template @@ -85,7 +86,7 @@ __global__ void const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, - const index_t batch_count, + const index_t groups_count, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock @@ -93,18 +94,22 @@ __global__ void const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_, const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + const ComputePtrOffsetOfG compute_ptr_offset_of_groups, + const ComputePtrOffsetOfN compute_ptr_offset_of_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) - // offset base pointer for each work-group - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + // offset base pointer for each work-group + const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count); + const index_t& num_blocks_per_n = groups_count; + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n); + + const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx); + const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + + const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -121,13 +126,28 @@ __global__ void AsPointer p_as_grid_grp; BsPointer p_bs_grid_grp; - const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx); + const auto& as_batch_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx); - static constexpr index_t 
NumATensor = AGridDesc_AK0_M_AK1::Size(); - static_for<0, NumATensor, 1>{}( - [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; }); + // compute_ptr_offset_of_n_ not need BatchStrideB so + // in case of MultiA is false but isMultiB is true + // BatchStrideA_ is not tuple. + if constexpr(isMultiA) + { + const auto& as_n_offset = compute_ptr_offset_of_n.GetAsPtrOffset(n_idx); - const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx); + static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); + static_for<0, NumATensor, 1>{}([&](auto i) { + p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + as_n_offset[i]; + }); + } + else + { + const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); + static_for<0, 1, 1>{}( + [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + a_n_offset; }); + } + + const auto& bs_batch_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx); static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); static_for<0, NumBTensor, 1>{}( @@ -137,7 +157,7 @@ __global__ void p_as_grid_grp, p_bs_grid_grp, p_ds_grid_grp, - p_e_grid + e_batch_offset, + p_e_grid + e_batch_offset + e_n_offset, p_shared, a_element_op, b_element_op, @@ -150,16 +170,16 @@ __global__ void } else { - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx); + + const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); GridwiseGemm::template Run( - p_as_grid + a_batch_offset, + p_as_grid + a_batch_offset + a_n_offset, p_bs_grid + b_batch_offset, p_ds_grid_grp, - p_e_grid + e_batch_offset, + p_e_grid + e_batch_offset + e_n_offset, p_shared, a_element_op, b_element_op, @@ -175,7 +195,7 @@ __global__ void ignore = p_bs_grid; ignore = p_ds_grid; ignore = p_e_grid; - ignore = batch_count; + ignore = groups_count; ignore = a_grid_desc_k0_m_k1; ignore = b_grid_desc_k0_n_k1; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; @@ -183,7 +203,8 @@ __global__ void ignore = a_element_op; ignore = b_element_op; ignore = cde_element_op; - ignore = compute_ptr_offset_of_batch; + ignore = compute_ptr_offset_of_groups; + ignore = compute_ptr_offset_of_n; ignore = block_2_ctile_map; #endif } @@ -309,7 +330,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, - const std::array& input_right_pads) + const std::array& input_right_pads, + const index_t Conv_N) { const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, @@ -321,7 +343,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + Conv_N); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -347,11 +370,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle template static auto MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides) + const std::array& e_g_n_k_wos_strides, + const index_t 
Conv_N) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -363,24 +387,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // Pass e_g_n_k_wos_lengths for logical broadcast. static auto MakeDsGridDescriptor_M_N( const std::array& e_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides) + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const index_t Conv_N) { return generate_tuple( [&](auto i) { using DLayout = remove_cvref_t>; - return DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - ds_g_n_k_wos_strides[i]); + return DeviceOp::MakeEGridDescriptor_M_N( + e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], Conv_N); }, Number{}); } // desc for problem definition using AGridDesc_M_K = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>; using BGridDesc_N_K = remove_cvref_t({}, {}))>; - using DsGridDesc_M_N = remove_cvref_t; - using EGridDesc_M_N = remove_cvref_t({}, {}))>; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = remove_cvref_t({}, {}, 1))>; // If we are using multiAB and one of the template datatype parameters is not a tuple, convert // it to it @@ -468,6 +493,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_c_wis_lengths[0]}, + conv_N_per_block_{ + conv_to_gemm_transformer.template GetSplitedNSize( + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides)}, a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, b_g_k_c_xs_lengths, @@ -477,12 +508,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads)}, + input_right_pads, + conv_N_per_block_)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, ds_grid_desc_m_n_{}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)}, a_grid_desc_ak0_m_ak1_{ GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ @@ -490,7 +522,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, - compute_ptr_offset_of_batch_{}, + compute_ptr_offset_of_groups_{}, + compute_ptr_offset_of_n_{}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, @@ -511,8 +544,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle if constexpr(isMultiA || isMultiB) { static_for<0, NumATensor, 1>{}([&](auto i) { - // Init compute_ptr_offset_of_batch_ for multiple AB - compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0]; + // Init compute_ptr_offset_of_groups_ for multiple AB + compute_ptr_offset_of_groups_.BatchStrideA_(i) = a_g_n_c_wis_strides[0]; // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data // type is not tuple) @@ -524,16 +557,23 @@ struct 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { // p_as is tuple p_as_grid_(i) = static_cast(p_as[i.value]); + // compute_ptr_offset_of_n_ not need BatchStrideB so + // in case of MultiA is false but isMultiB is true + // BatchStrideA_ is not tuple. + compute_ptr_offset_of_n_.BatchStrideA_(i) = + a_g_n_c_wis_strides[1] * conv_N_per_block_; } else { // if MultiB and not MultiA then p_as is single pointer p_as_grid_(i) = static_cast(p_as); + compute_ptr_offset_of_n_.BatchStrideA_ = + a_g_n_c_wis_strides[1] * conv_N_per_block_; } }); static_for<0, NumBTensor, 1>{}([&](auto i) { - // Init compute_ptr_offset_of_batch_ for multiple AB - compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0]; + // Init compute_ptr_offset_of_groups_ for multiple AB + compute_ptr_offset_of_groups_.BatchStrideB_(i) = b_g_k_c_xs_strides[0]; using DataType = remove_cvref_t>; // It is possible that one of the AB is a pointer and one is a tuple. @@ -553,8 +593,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } else { - compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; // p_as and p_bs are pointers p_as_grid_(I0) = static_cast(p_as); @@ -570,13 +611,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle p_ds_grid_(i) = static_cast(p_ds[i]); // D batch stride - compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + compute_ptr_offset_of_n_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides[i][1] * conv_N_per_block_; // D desc ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( - e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]); + e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], conv_N_per_block_); }); - compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; // populate desc for Ds/E if constexpr(isMultiA || isMultiB) @@ -638,6 +682,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // tensor descriptors for problem definiton index_t num_group_; + index_t conv_N_per_block_; + AGridDesc_M_K a_grid_desc_m_k_; BGridDesc_N_K b_grid_desc_n_k_; DsGridDesc_M_N ds_grid_desc_m_n_; @@ -655,7 +701,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // for computing batch offset ComputePtrOffsetOfStridedBatch - compute_ptr_offset_of_batch_; + compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; // element-wise op AElementwiseOperation a_element_op_; @@ -689,8 +736,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.Print(); } - const index_t grid_size = - arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_; + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + + const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N; + const index_t gdz = 1; const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); @@ -721,6 +772,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle 
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, Block2ETileMap, ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop, isMultiA, isMultiB>; @@ -728,7 +780,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return launch_and_time_kernel( stream_config, kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg.p_as_grid_, @@ -744,7 +796,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_etile_map_, - arg.compute_ptr_offset_of_batch_); + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); } else { @@ -763,6 +816,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, Block2ETileMap, ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop, isMultiA, isMultiB>; @@ -770,7 +824,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return launch_and_time_kernel( stream_config, kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg.p_as_grid_.At(I0), // Pass just A descriptor instead of tuple @@ -786,7 +840,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_etile_map_, - arg.compute_ptr_offset_of_batch_); + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 986c41c518..ba9d967e97 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -60,7 +60,7 @@ template (compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx); + const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx); + + const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); + const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -97,9 +99,9 @@ __global__ void CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, HasMainKBlockLoop, CGlobalMemoryDataOperation, - TailNum>(karg.p_a_grid + a_batch_offset, + TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset, karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, + karg.p_c_grid + e_batch_offset + e_n_offset, p_shared, karg, a_grid_desc_ak0_m_ak1, @@ -114,7 +116,7 @@ template (compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - 
static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx); + const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx); + + const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); + const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx); // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy @@ -154,9 +159,9 @@ __global__ void CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, HasMainKBlockLoop, CGlobalMemoryDataOperation, - TailNum>(karg.p_a_grid + a_batch_offset, + TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset, karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, + karg.p_c_grid + e_batch_offset + e_n_offset, p_shared_0, p_shared_1, karg, @@ -294,7 +299,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, - const std::array& input_right_pads) + const std::array& input_right_pads, + const index_t Conv_N) + { const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, @@ -306,7 +313,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + Conv_N); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -350,11 +358,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 template static auto MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides) + const std::array& e_g_n_k_wos_strides, + const index_t Conv_N) + { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -363,7 +373,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } // desc for problem definition - using EGridDesc_M_N = remove_cvref_t({}, {}))>; + using EGridDesc_M_N = remove_cvref_t({}, {}, 1))>; #define GridwiseGemmV3TemplateParams \ tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \ @@ -396,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // desc for blockwise copy using AGridDesc_AK0_M_AK1 = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>; using BGridDesc_BK0_N_BK1 = remove_cvref_t({}, {}))>; using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -429,6 +439,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 p_b_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_c_wis_lengths[0]}, + conv_N_per_block_{ + conv_to_gemm_transformer.template GetSplitedNSize( + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides)}, a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, b_g_k_c_xs_lengths, @@ -438,13 +454,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 conv_filter_strides, conv_filter_dilations, input_left_pads, - 
input_right_pads)}, + input_right_pads, + conv_N_per_block_)}, b_grid_desc_bk0_n_bk1_{ MakeBGridDescriptor_BK0_N_BK1(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - compute_ptr_offset_of_batch_{}, + compute_ptr_offset_of_groups_{}, + compute_ptr_offset_of_n_{}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, @@ -459,15 +477,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads} { - // A/B/E Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + // A/B/E Batch/N Stride + compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; // p_as and p_bs are pointers p_a_grid_ = static_cast(p_as); p_b_grid_ = static_cast(p_bs); - compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; e_grid_desc_mblock_mperblock_nblock_nperblock_ = MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); @@ -488,6 +508,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // tensor descriptors for problem definiton index_t num_group_; + index_t conv_N_per_block_; // tensor descriptors for block/thread-wise copy AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; @@ -496,7 +517,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; // element-wise op AElementwiseOperation a_element_op_; @@ -538,11 +560,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + index_t gdx, gdy, gdz; std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); - gdy *= arg.num_group_; + gdy *= arg.num_group_ * num_workgroups_per_Conv_N; index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); @@ -579,7 +604,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.compute_ptr_offset_of_batch_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_, arg.num_group_); } else @@ -594,7 +620,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.compute_ptr_offset_of_batch_, + arg.compute_ptr_offset_of_groups_, + 
arg.compute_ptr_offset_of_n_, arg.num_group_); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index ab1c4fc08f..114fcbfcff 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -338,7 +338,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + a_g_n_c_wis_lengths[I1]); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -367,8 +368,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle const std::array& e_g_n_k_wos_strides) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 24bd0f242c..d5cc5dc758 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
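Note on the two forward device hunks above (device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp and its _v3 variant): instead of one 1-D grid of size tiles * num_group_, the launch now uses a y-dimension that enumerates (group, N-slice) pairs, where an N-slice covers conv_N_per_block_ images chosen so that each slice stays under 2 GiB. A second ComputePtrOffsetOfStridedBatch instance, compute_ptr_offset_of_n_, advances the A, Ds and E pointers by the N stride times conv_N_per_block_ for each slice. A minimal sketch of the launch arithmetic; make_grid is an illustrative helper, not something the patch defines:

    #include <cstdint>

    using index_t = int32_t;

    struct GridShape
    {
        index_t gdx, gdy, gdz;
    };

    // tiles_per_gemm plays the role of block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n_).
    GridShape make_grid(index_t tiles_per_gemm, index_t num_group, index_t N, index_t conv_N_per_block)
    {
        // GetSplitedNSize only ever returns divisors of N, so this division is exact.
        const index_t num_workgroups_per_conv_n = N / conv_N_per_block;
        return GridShape{tiles_per_gemm, num_group * num_workgroups_per_conv_n, 1};
    }

Inside the kernel the group index is then recovered from blockIdx.y, and the output pointer is shifted by both the group offset and the N-slice offset before the GEMM runs.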
#pragma once @@ -163,7 +163,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + a_g_n_c_wis_lengths[I1]); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -255,8 +256,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle const std::array& e_g_n_k_wos_strides) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp index 9ae10441f9..c20e5d36f8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp @@ -68,14 +68,14 @@ template struct ComputePtrOffsetOfStridedBatch 1 || NumBTensor > 1)>> + enable_if_t<(NumATensor > 1 || NumBTensor > 1)>> { ComputePtrOffsetOfStridedBatch() = default; - ComputePtrOffsetOfStridedBatch(Array& BatchStrideAs, - Array& BatchStrideBs, - Array& BatchStrideDs, - index_t BatchStrideE) + ComputePtrOffsetOfStridedBatch(Array& BatchStrideAs, + Array& BatchStrideBs, + Array& BatchStrideDs, + long_index_t BatchStrideE) : BatchStrideA_(BatchStrideAs), BatchStrideB_(BatchStrideBs), BatchStrideDs_(BatchStrideDs), @@ -87,7 +87,7 @@ struct ComputePtrOffsetOfStridedBatch as_offset; static_for<0, NumATensor, 1>{}( - [&](auto i) { as_offset(i) = g_idx * static_cast(BatchStrideA_[i]); }); + [&](auto i) { as_offset(i) = static_cast(g_idx) * BatchStrideA_[i]; }); return as_offset; } @@ -95,7 +95,7 @@ struct ComputePtrOffsetOfStridedBatch bs_offset; static_for<0, NumBTensor, 1>{}( - [&](auto i) { bs_offset(i) = g_idx * static_cast(BatchStrideB_[i]); }); + [&](auto i) { bs_offset(i) = static_cast(g_idx) * BatchStrideB_[i]; }); return bs_offset; } @@ -103,40 +103,40 @@ struct ComputePtrOffsetOfStridedBatch ds_offset; static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + [&](auto i) { ds_offset(i) = static_cast(g_idx) * BatchStrideDs_[i]; }); return ds_offset; } [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const { - return g_idx * static_cast(BatchStrideE_); + return static_cast(g_idx) * BatchStrideE_; } // alias for kernels without multiple D [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const { - return g_idx * static_cast(BatchStrideE_); + return static_cast(g_idx) * BatchStrideE_; } - Array BatchStrideA_; - Array BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; - index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D + Array BatchStrideA_; + Array BatchStrideB_; + Array BatchStrideDs_; + long_index_t BatchStrideE_; + long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D }; template struct ComputePtrOffsetOfStridedBatch> + enable_if_t<(NumATensor == 1 && NumBTensor == 1)>> { ComputePtrOffsetOfStridedBatch() = default; - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) + 
ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA, + long_index_t BatchStrideB, + Array BatchStrideDs, + long_index_t BatchStrideE) : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideDs_(BatchStrideDs), @@ -146,38 +146,38 @@ struct ComputePtrOffsetOfStridedBatch(BatchStrideA_); + return static_cast(g_idx) * BatchStrideA_; } __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const { - return g_idx * static_cast(BatchStrideB_); + return static_cast(g_idx) * BatchStrideB_; } __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const { Array ds_offset; static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + [&](auto i) { ds_offset(i) = static_cast(g_idx) * BatchStrideDs_[i]; }); return ds_offset; } [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const { - return g_idx * static_cast(BatchStrideE_); + return static_cast(g_idx) * BatchStrideE_; } // alias for kernels without multiple D [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const { - return g_idx * static_cast(BatchStrideE_); + return static_cast(g_idx) * BatchStrideE_; } - ck::index_t BatchStrideA_; - ck::index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; - index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D + long_index_t BatchStrideA_; + long_index_t BatchStrideB_; + Array BatchStrideDs_; + long_index_t BatchStrideE_; + long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D }; template diff --git a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp index 52aeefa3a4..9ebcb2b8c0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -108,7 +108,8 @@ struct DeviceImageToColumnImpl conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + N); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 82d010a99a..dc639e995e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
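The device_grouped_conv_utils.hpp hunk above widens every BatchStride member of ComputePtrOffsetOfStridedBatch from index_t to long_index_t and applies the 64-bit cast to the group index before the multiply. The widening matters because a stride describing a large tensor can no longer be held in 32 bits. A small standalone check of that limit, with made-up numbers for illustration:

    #include <cstdint>
    #include <cstdio>

    using index_t      = int32_t;
    using long_index_t = int64_t;

    int main()
    {
        // A per-group slice of 64 x 192 x 224 x 224 elements still fits in 32 bits...
        const long_index_t slice    = 64LL * 192 * 224 * 224; // 616'562'688
        // ...but a stride spanning several such slices does not.
        const long_index_t stride64 = 8 * slice;              // 4'932'501'504

        const index_t stride32 = static_cast<index_t>(stride64); // wraps: not representable
        std::printf("stored as index_t: %d, stored as long_index_t: %lld\n",
                    stride32, static_cast<long long>(stride64));
        return 0;
    }

With the members widened, group offsets are formed as static_cast<long_index_t>(g_idx) * BatchStride_, so the product is 64-bit from the start.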
#pragma once @@ -60,12 +60,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); @@ -155,12 +152,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index e2f75142d4..3097a32937 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
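The transform_conv_fwd_to_gemm.hpp hunk that follows carries the core policy: calculate_element_space_size_impl measures the addressable span of a tensor as 1 plus the sum, over the dimensions after G, of (length - 1) * stride, and GetSplitedNSize decides how many images of N a single launch may cover so that neither the input nor the output exceeds 2 GiB. A standalone, readable sketch of that selection, reconstructed from the hunk below; the free-function packaging and the names element_space_size/split_n_size are only for illustration, and element sizes are passed in where the original takes the data types as template parameters:

    #include <algorithm>
    #include <array>
    #include <cstdint>

    using index_t      = int32_t;
    using long_index_t = int64_t;

    template <std::size_t NDims>
    long_index_t element_space_size(const std::array<index_t, NDims>& lengths,
                                    const std::array<index_t, NDims>& strides,
                                    std::size_t first_dim) // skip G, start at the N dimension
    {
        long_index_t acc = 1;
        for(std::size_t i = first_dim; i < NDims; ++i)
            acc += static_cast<long_index_t>(lengths[i] - 1) * strides[i];
        return acc;
    }

    // Returns how many images of the batch one launch should process
    // (the patch stores this value as conv_N_per_block_).
    template <std::size_t NDims>
    index_t split_n_size(const std::array<index_t, NDims>& a_lengths,
                         const std::array<index_t, NDims>& a_strides,
                         const std::array<index_t, NDims>& c_lengths,
                         const std::array<index_t, NDims>& c_strides,
                         long_index_t a_bytes_per_element,
                         long_index_t c_bytes_per_element)
    {
        constexpr long_index_t TwoGB = long_index_t{1} << 31;

        const long_index_t bytes =
            std::max(element_space_size(a_lengths, a_strides, 1) * a_bytes_per_element,
                     element_space_size(c_lengths, c_strides, 1) * c_bytes_per_element);

        const index_t N = a_lengths[1]; // layout order is G, N, C, spatial...
        if(bytes <= TwoGB)
            return N; // no split needed

        // Smallest divisor of N (searched up to sqrt(N)) that brings each slice under 2 GiB.
        const long_index_t min_divisor = (bytes + TwoGB - 1) / TwoGB;
        if(min_divisor > N)
            return N; // splitting N alone cannot help

        for(long_index_t d = min_divisor; d * d <= N; ++d)
            if(N % d == 0)
                return static_cast<index_t>(N / d);

        return 1; // fall back to one image per launch
    }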
#pragma once @@ -20,6 +20,71 @@ struct TransformConvFwdToGemm static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; + static long_index_t + calculate_element_space_size_impl(const std::array& lengths, + const std::array& strides, + index_t i) + { + long_index_t acc = 1; + for(; i < (NDimSpatial + 3); i++) + { + acc += + static_cast(lengths[i] - I1) * static_cast(strides[i]); + } + + return acc; + } + + template + static index_t GetSplitedNSize(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides) + { + const long_index_t a_element_space_size = + calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); + const long_index_t c_element_space_size = + calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1); + const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), + c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const index_t N = a_g_n_c_wis_lengths[I1]; + + if(element_space_size > TwoGB) + { + // Minimum divisor of N to not exceed 2GB + const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); + + if(divisor <= static_cast(N)) + { + // Find least divisor of N larger than element_space_size / TwoGB + // Iterate up to sqrt(N). There are no divisors above this value. + for(index_t least_divisor = divisor; least_divisor * least_divisor <= N; + least_divisor++) + { + if(N % least_divisor == 0) + { + return N / least_divisor; + } + } + // Not found, process one Convolution N per block + return 1; + } + else + { + // Not possible to support even after split N. + // Too large tensor. + return N; + } + } + else + { + // Split N is not needed. 
+ return N; + } + } + // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as // properties template & conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, - const std::array& input_right_pads) + const std::array& input_right_pads, + const index_t N) { - const index_t N = a_g_n_c_wis_lengths[1]; const index_t C = a_g_n_c_wis_lengths[2]; const index_t Wi = a_g_n_c_wis_lengths[3]; @@ -151,9 +216,10 @@ struct TransformConvFwdToGemm const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, - const std::array& input_right_pads) + const std::array& input_right_pads, + const index_t N) + { - const index_t N = a_g_n_c_wis_lengths[1]; const index_t C = a_g_n_c_wis_lengths[2]; const index_t Hi = a_g_n_c_wis_lengths[3]; @@ -276,13 +342,14 @@ struct TransformConvFwdToGemm const std::array& b_g_k_c_xs_lengths, const std::array& /* b_g_k_c_xs_strides */, const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, + const std::array& /* c_g_n_k_wos_strides*/, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, - const std::array& input_right_pads) + const std::array& input_right_pads, + const index_t N) + { - const index_t N = a_g_n_c_wis_lengths[1]; const index_t C = a_g_n_c_wis_lengths[2]; const index_t Di = a_g_n_c_wis_lengths[3]; @@ -478,9 +545,9 @@ struct TransformConvFwdToGemm bool>::type = false> static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */) + const std::array& /* c_g_n_k_wos_strides */, + const index_t N) { - const index_t N = c_g_n_k_wos_lengths[1]; const index_t K = c_g_n_k_wos_lengths[2]; const index_t NHoWo = @@ -502,9 +569,9 @@ struct TransformConvFwdToGemm is_same_v, bool>::type = false> static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& c_g_n_k_wos_strides) + const std::array& c_g_n_k_wos_strides, + const index_t N) { - const index_t N = c_g_n_k_wos_lengths[1]; const index_t K = c_g_n_k_wos_lengths[2]; const auto KStride = I1; @@ -525,9 +592,9 @@ struct TransformConvFwdToGemm typename std::enable_if, bool>::type = false> static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& c_g_n_k_wos_strides) + const std::array& c_g_n_k_wos_strides, + const index_t N) { - const index_t N = c_g_n_k_wos_lengths[1]; const index_t K = c_g_n_k_wos_lengths[2]; const index_t KStride = c_g_n_k_wos_strides[2]; diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp index 125e4dc48c..21fe7992ac 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp @@ -69,6 +69,8 @@ using KernelTypes3d = ::testing::Types std::tuple, std::tuple>; +using KernelTypes2dLargeCases = ::testing::Types>; + template class TestGroupedConvndFwd1d : public TestGroupedConvndFwd { @@ -84,9 +86,15 @@ class TestGroupedConvndFwd3d : public TestGroupedConvndFwd { }; +template +class TestGroupedConvndFwd2dLargeCases : public TestGroupedConvndFwd +{ +}; + TYPED_TEST_SUITE(TestGroupedConvndFwd1d, KernelTypes1d); TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d); TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d); +TYPED_TEST_SUITE(TestGroupedConvndFwd2dLargeCases, KernelTypes2dLargeCases); TYPED_TEST(TestGroupedConvndFwd1d, Test1D) { @@ -131,3 
+139,11 @@ TYPED_TEST(TestGroupedConvndFwd3d, Test3D) {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->template Run<3>(); } + +TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases) +{ + // Case larger than 2GB + this->conv_params.push_back( + {2, 1, 64, 4, 192, {2, 2}, {224, 224}, {224, 224}, {0, 0}, {0, 0}, {0, 0}}); + this->template Run<2>(); +}
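For scale, the new 2-D large case above (reading the parameter list as {ndims, G, N, K, C, filter, image, strides, dilations, pads}: G=1, N=64, K=4, C=192, a 224x224 input and a 2x2 filter) keeps the weight and output tiny while the input activation holds about 6.2e8 elements. Assuming a packed layout and a 4-byte element type (whichever types KernelTypes2dLargeCases actually instantiates), that is roughly 2.47 GB, so GetSplitedNSize should split N=64 into two slices of 32 images at about 1.23 GB each. A quick back-of-the-envelope check:

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        // Input activation of the added case, assuming 4 bytes per element.
        const int64_t G = 1, N = 64, C = 192, Hi = 224, Wi = 224;

        const int64_t elements = G * N * C * Hi * Wi;              // 616'562'688
        const int64_t bytes    = elements * 4;                     // ~2.47 GB
        const int64_t two_gib  = int64_t{1} << 31;

        const int64_t divisor  = (bytes + two_gib - 1) / two_gib;  // 2
        std::printf("%lld bytes -> split N=%lld into %lld slices of %lld images\n",
                    static_cast<long long>(bytes),
                    static_cast<long long>(N),
                    static_cast<long long>(divisor),
                    static_cast<long long>(N / divisor));
        return 0;
    }

With a 2-byte element type the same shape stays under the 2 GiB limit, so the assumption about the instantiated type matters for whether this case exercises the split path.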