From c885afdaae453a744d76e989dda7dda7e4f7e1cf Mon Sep 17 00:00:00 2001 From: Bartłomiej Kocot Date: Fri, 12 Jul 2024 20:08:42 +0200 Subject: [PATCH] Support access per groups and filter3x3 in grouped conv fwd (#1382) * Support access per groups and filter3x3 in grouped conv fwd * Fixes for large cases * Fixes for large tensors [ROCm/composable_kernel commit: 82e8a78a3f5ed8906162bf48d22fdf525c99aa12] --- .../convolution_forward_specialization.hpp | 4 +- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 76 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 127 ++- .../device/impl/device_grouped_conv_utils.hpp | 16 + .../transform_conv_bwd_weight_to_gemm_v2.hpp | 90 +- .../transform_conv_fwd_to_gemm.hpp | 961 ++++++++++++++---- ...conv_bwd_weight_two_stage_xdl_instance.hpp | 8 +- ...ed_conv_fwd_xdl_merged_groups_instance.hpp | 96 ++ .../gpu/grouped_convolution_forward.hpp | 13 + ..._convolution_forward_xdl_merged_groups.inc | 112 ++ .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 5 + ...groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 48 + ..._groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 48 + ..._groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp | 48 + .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 4 + ...ups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 47 + ...oups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 47 + ...oups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 47 + .../test_grouped_convnd_fwd.cpp | 10 +- 19 files changed, 1471 insertions(+), 336 deletions(-) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp index adfa1689c6..0eef827a5b 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
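The patch pairs two related ideas: a Filter3x3 specialization that fixes every filter extent to a compile-time 3 (valid only when each group has a single input channel, C == 1), and a NumGroupsToMerge parameter that folds several groups into one GEMM so small per-group problems still fill a workgroup. The launch grid then indexes packs of merged groups on Y and the split over N on Z. A minimal sketch of that grid arithmetic, assuming the names used in the launcher changes further down (the helper itself is illustrative, not CK code):

    #include <tuple>
    using index_t = int;

    // gdx walks the GEMM tiles, gdy the packs of merged groups, gdz the N split.
    std::tuple<index_t, index_t, index_t> make_launch_grid(index_t tiles_per_gemm,
                                                           index_t G,
                                                           index_t num_groups_to_merge,
                                                           index_t workgroups_per_n)
    {
        // IsSupportedArgument() rejects cases where G % num_groups_to_merge != 0
        return {tiles_per_gemm, G / num_groups_to_merge, workgroups_per_n};
    }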
#pragma once @@ -15,6 +15,7 @@ enum struct ConvolutionForwardSpecialization Filter1x1Pad0, Filter1x1Stride1Pad0, OddC, + Filter3x3, }; inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s) @@ -25,6 +26,7 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0"; case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; case ConvolutionForwardSpecialization::OddC: return "OddC"; + case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3"; default: return "Unrecognized specialization!"; } } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index d9e300b737..e18b8b9e28 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -36,7 +36,7 @@ template struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle @@ -238,7 +238,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle NPerBlock, K1Number, KPerBlock / K1Number, - NumBatchToMerge, + NumGroupsToMerge, ConvBackwardWeightSpecialization>{}; static constexpr auto conv_to_gemm_transformer_v1 = @@ -638,7 +638,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle index_t gdx, gdy, gdz; std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( - gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumBatchToMerge); + gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge); float ave_time = 0; @@ -724,7 +724,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>; @@ -739,7 +739,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy>; @@ -760,7 +760,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -777,7 +777,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -796,7 +796,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -817,7 +817,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -838,7 +838,7 @@ struct 
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -859,7 +859,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -879,7 +879,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -900,7 +900,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -920,7 +920,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -937,7 +937,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -956,7 +956,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -977,7 +977,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -998,7 +998,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1019,7 +1019,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1039,7 +1039,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1060,7 +1060,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1084,7 +1084,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, 
+ NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -1100,7 +1100,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -1119,7 +1119,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1135,7 +1135,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1157,7 +1157,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -1173,7 +1173,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -1192,7 +1192,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1208,7 +1208,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1232,7 +1232,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, false, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>; @@ -1247,7 +1247,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, false, InMemoryDataOperationEnum::Set, minimum_occupancy>; @@ -1389,7 +1389,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } } - if constexpr(NumBatchToMerge > 1) + if constexpr(NumGroupsToMerge > 1) { // support only if whole M and N can be processed on one block if(!(GemmM <= MPerBlock && GemmN <= NPerBlock)) @@ -1400,7 +1400,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { return false; } - if(arg.Conv_G_ % NumBatchToMerge != 0) + if(arg.Conv_G_ % NumGroupsToMerge != 0) { return false; } @@ -1563,7 +1563,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " - << NumBatchToMerge + << NumGroupsToMerge << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index f5a8d4e9f7..2ee17c5a0a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -86,7 +86,6 @@ __global__ void const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, - const index_t groups_count, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock @@ -101,14 +100,11 @@ __global__ void defined(__gfx94__)) // offset base pointer for each work-group - const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count); - const index_t& num_blocks_per_n = groups_count; - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n); - - const long_index_t e_batch_offset = + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const long_index_t e_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); - const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); const long_index_t e_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); @@ -121,14 +117,14 @@ __global__ void DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; }); if constexpr(isMultiA || isMultiB) { AsPointer p_as_grid_grp; BsPointer p_bs_grid_grp; - const auto& as_batch_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx); + const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx); // compute_ptr_offset_of_n_ does not need BatchStrideB, so // in case isMultiA is false but isMultiB is true @@ -139,27 +135,27 @@ __global__ void static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); static_for<0, NumATensor, 1>{}([&](auto i) { - p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + as_n_offset[i]; + p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i]; }); } else { const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); static_for<0, 1, 1>{}( - [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + a_n_offset; }); + [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; }); } - const auto& bs_batch_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx); + const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx); static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); static_for<0, NumBTensor, 1>{}( - [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; }); + [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; }); GridwiseGemm::template Run( p_as_grid_grp, p_bs_grid_grp, p_ds_grid_grp, - p_e_grid + e_batch_offset + e_n_offset, + p_e_grid + e_group_offset + e_n_offset, p_shared, a_element_op, b_element_op, @@ -172,19 +168,19 @@ __global__ void
else { - const long_index_t a_batch_offset = + const long_index_t a_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = + const long_index_t b_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); const long_index_t a_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); GridwiseGemm::template Run( - p_as_grid + a_batch_offset + a_n_offset, - p_bs_grid + b_batch_offset, + p_as_grid + a_group_offset + a_n_offset, + p_bs_grid + b_group_offset, p_ds_grid_grp, - p_e_grid + e_batch_offset + e_n_offset, + p_e_grid + e_group_offset + e_n_offset, p_shared, a_element_op, b_element_op, @@ -200,7 +196,6 @@ __global__ void ignore = p_bs_grid; ignore = p_ds_grid; ignore = p_e_grid; - ignore = groups_count; ignore = a_grid_desc_k0_m_k1; ignore = b_grid_desc_k0_n_k1; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; @@ -287,7 +282,8 @@ template + LoopScheduler LoopSched = make_default_loop_scheduler(), + index_t NumGroupsToMerge = 1> struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle : public DeviceGroupedConvFwdMultipleABD= 1); + static constexpr bool isMultiA = is_detected::value; static constexpr bool isMultiB = is_detected::value; @@ -319,7 +317,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle static constexpr auto I3 = Number<3>{}; static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + TransformConvFwdToGemm{}; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; @@ -550,7 +548,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { static_for<0, NumATensor, 1>{}([&](auto i) { // Init compute_ptr_offset_of_groups_ for multiple AB - compute_ptr_offset_of_groups_.BatchStrideA_(i) = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideA_(i) = + a_g_n_c_wis_strides[0] * NumGroupsToMerge; // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data // type is not tuple) @@ -578,7 +577,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle }); static_for<0, NumBTensor, 1>{}([&](auto i) { // Init compute_ptr_offset_of_groups_ for multiple AB - compute_ptr_offset_of_groups_.BatchStrideB_(i) = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideB_(i) = + b_g_k_c_xs_strides[0] * NumGroupsToMerge; using DataType = remove_cvref_t>; // It is possible that one of the AB is a pointer and one is a tuple. 
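Because a work-group now covers NumGroupsToMerge consecutive groups, the per-group pointer strides stored in compute_ptr_offset_of_groups_ are pre-multiplied by the merge factor (see the BatchStride* initializations above), and the kernel reads its group index straight from blockIdx.y. A rough sketch of the offset a work-group ends up applying, assuming GetEPtrOffset() simply multiplies the index by the stored stride (that accessor's body is not shown in this diff):

    using index_t      = int;
    using long_index_t = long long;

    // One blockIdx.y step advances past a whole pack of merged groups.
    long_index_t e_ptr_offset(index_t g_idx,          // blockIdx.y
                              long_index_t g_stride,  // e_g_n_k_wos_strides[0]
                              index_t num_groups_to_merge)
    {
        const long_index_t batch_stride_e = g_stride * num_groups_to_merge;
        return static_cast<long_index_t>(g_idx) * batch_stride_e;
    }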
@@ -598,8 +598,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } else { - compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; - compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideA_ = + a_g_n_c_wis_strides[0] * NumGroupsToMerge; + compute_ptr_offset_of_groups_.BatchStrideB_ = + b_g_k_c_xs_strides[0] * NumGroupsToMerge; compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; // p_as and p_bs are pointers @@ -616,7 +618,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle p_ds_grid_(i) = static_cast(p_ds[i]); // D batch stride - compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides[i][0] * NumGroupsToMerge; compute_ptr_offset_of_n_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][1] * conv_N_per_block_; @@ -624,7 +627,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], conv_N_per_block_); }); - compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0] * NumGroupsToMerge; compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; // populate desc for Ds/E @@ -745,8 +748,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); - const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N; - const index_t gdz = 1; + const index_t gdy = arg.num_group_ / NumGroupsToMerge; + const index_t gdz = num_workgroups_per_Conv_N; const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); @@ -795,7 +798,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.a_element_op_, arg.b_element_op_, arg.cde_element_op_, - arg.a_g_n_c_wis_lengths_[0], // Group count as_grid_desc_ak0_m_ak1, bs_grid_desc_bk0_n_bk1, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, @@ -839,7 +841,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.a_element_op_, arg.b_element_op_, arg.cde_element_op_, - arg.a_g_n_c_wis_lengths_[0], // Group count arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, @@ -871,6 +872,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; + const index_t G = arg.b_g_k_c_xs_lengths_[I0]; + const index_t K = arg.b_g_k_c_xs_lengths_[I1]; + const index_t C = arg.b_g_k_c_xs_lengths_[I2]; + // check device if(get_device_name() == "gfx908") { @@ -919,6 +924,42 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } } } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3) + { + if(C != 1) + { + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3]; + + if(filter_spatial_dim != I3) + { + return false; + } + } + if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC()) + { + return false; + } + } + + if constexpr(NumGroupsToMerge > 1) + { + if(!(C == 1)) + { + return false; + } + if(G % NumGroupsToMerge != 0) + { + return false; + } + if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC()) + { + return false; + } + } // check vector access of A // FIXME: layout @@ -928,11 +969,16 @@ 
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v || is_same_v || is_same_v) { - const index_t C = arg.a_g_n_c_wis_lengths_[2]; - + // Check access per C if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) { - return false; + // If not possible, check access per G + if(!(ABlockTransferSrcVectorDim == 1 && C == 1 && + is_NSpatialGK_GKSpatial_NSpatialGC() && + G % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } } } else @@ -949,8 +995,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v) { - const index_t C = arg.b_g_k_c_xs_lengths_[2]; - if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) { return false; @@ -974,8 +1018,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v || is_same_v || is_same_v || is_same_v) { - const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; - if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) { valid = false; @@ -1020,8 +1062,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v || is_same_v || is_same_v) { - const index_t K = arg.e_g_n_k_wos_lengths_[2]; - if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) { return false; @@ -1172,7 +1212,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle << BBlockTransferSrcScalarPerVector << ", " << CDEBlockTransferScalarPerVector_NPerBlock << ", " << CShuffleMXdlPerWavePerShuffle << ", " - << CShuffleNXdlPerWavePerShuffle + << CShuffleNXdlPerWavePerShuffle << ", " + << NumGroupsToMerge << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp index c20e5d36f8..3ee02558f4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp @@ -59,6 +59,22 @@ constexpr bool is_GNDHWK_GKZYXC_GNDHWC() is_same_v; } +template +constexpr bool is_NSpatialGK_GKSpatial_NSpatialGC() +{ + return is_NWGK_GKXC_NWGC() || + is_NHWGK_GKYXC_NHWGC() || + is_NDHWGK_GKZYXC_NDHWGC(); +} + +template +constexpr bool is_GNSpatialK_GKSpatial_GNSpatialC() +{ + return is_GNWK_GKXC_GNWC() || + is_GNHWK_GKYXC_GNHWC() || + is_GNDHWK_GKZYXC_GNDHWC(); +} + template struct ComputePtrOffsetOfStridedBatch { diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index 158890d7a3..bc290d5641 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -27,7 +27,7 @@ template struct TransformConvBwdWeightToGemmV2 { @@ -45,7 +45,7 @@ struct TransformConvBwdWeightToGemmV2 const index_t BatchStride = output_strides[0]; const index_t WoStride = output_strides[4]; const auto KStride = Number<1>{}; - return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumBatchToMerge, K), + return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumGroupsToMerge, K), make_tuple(WoStride, BatchStride, KStride)); } @@ -65,13 +65,13 @@ struct TransformConvBwdWeightToGemmV2 if constexpr(ConvBackwardWeightSpecialization == device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { - return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, NumBatchToMerge, C), + return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, 
NumGroupsToMerge, C), make_tuple(WiStride, BatchStride, CStride)); } else { return make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, NumBatchToMerge, C), + make_tuple(N, Hi, Wi, NumGroupsToMerge, C), make_tuple(NStride, HiStride, WiStride, BatchStride, CStride)); } } @@ -88,30 +88,30 @@ struct TransformConvBwdWeightToGemmV2 const auto KStride = weights_strides[1]; const auto XStride = weights_strides[4]; const auto BatchStride = weights_strides[0]; - // Add NumBatchToMerge for Batch+M dimension and 1 as a placeholder + // Add NumGroupsToMerge for Batch+M dimension and 1 as a placeholder // for Batch+N dimension const auto desc = make_naive_tensor_descriptor( - make_tuple(NumBatchToMerge, K, Y * X, 1, C), + make_tuple(NumGroupsToMerge, K, Y * X, 1, C), make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); - // Pad 1 to NumBatchToMerge + // Pad 1 to NumGroupsToMerge const auto padded_desc = transform_tensor_descriptor( desc, - make_tuple(make_pass_through_transform(NumBatchToMerge), + make_tuple(make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(K), make_pass_through_transform(Y * X), - make_pad_transform(1, 0, NumBatchToMerge - 1), + make_pad_transform(1, 0, NumGroupsToMerge - 1), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); // We need only the matrices from the diagonal. Xor returns 0 for the same // values, so if a matrix is not on the diagonal it will be stored in the padding. // To avoid use of modulo after the xor we assume that NumGroupsToMerge is a power of 2. - static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 || - NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 || - NumBatchToMerge == 64); + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); const auto unmerged_padded_desc = transform_tensor_descriptor( padded_desc, - make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)), + make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), make_pass_through_transform(K), make_pass_through_transform(Y * X), make_pass_through_transform(C)), @@ -120,8 +120,8 @@ struct TransformConvBwdWeightToGemmV2 // Merge To M, N return transform_tensor_descriptor( unmerged_padded_desc, - make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)), - make_merge_transform(make_tuple(Y * X, NumBatchToMerge, C))), + make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)), + make_merge_transform(make_tuple(Y * X, NumGroupsToMerge, C))), make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -138,7 +138,7 @@ struct TransformConvBwdWeightToGemmV2 const index_t BatchStride = output_strides[0]; const index_t WoStride = output_strides[5]; const auto KStride = Number<1>{}; - return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumBatchToMerge, K), + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumGroupsToMerge, K), make_tuple(WoStride, BatchStride, KStride)); } @@ -160,13 +160,13 @@ struct TransformConvBwdWeightToGemmV2 if constexpr(ConvBackwardWeightSpecialization == device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { - return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi,
NumBatchToMerge, C), + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, NumGroupsToMerge, C), make_tuple(WiStride, BatchStride, CStride)); } else { return make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, NumBatchToMerge, C), + make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C), make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride)); } } @@ -184,29 +184,29 @@ struct TransformConvBwdWeightToGemmV2 const auto KStride = weights_strides[1]; const auto XStride = weights_strides[5]; const auto BatchStride = weights_strides[0]; - // Add NumBatchToMerge for Batch+M dimension and 1 as a placeholder for Batch+N dimension + // Add NumGroupsToMerge for Batch+M dimension and 1 as a placeholder for Batch+N dimension const auto desc = make_naive_tensor_descriptor( - make_tuple(NumBatchToMerge, K, Z * Y * X, 1, C), + make_tuple(NumGroupsToMerge, K, Z * Y * X, 1, C), make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); - // Pad 1 to NumBatchToMerge + // Pad 1 to NumGroupsToMerge const auto padded_desc = transform_tensor_descriptor( desc, - make_tuple(make_pass_through_transform(NumBatchToMerge), + make_tuple(make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(K), make_pass_through_transform(Z * Y * X), - make_pad_transform(1, 0, NumBatchToMerge - 1), + make_pad_transform(1, 0, NumGroupsToMerge - 1), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); // We need only the matrices from the diagonal. Xor returns 0 for the same // values, so if a matrix is not on the diagonal it will be stored in the padding. // To avoid use of modulo after the xor we assume that NumGroupsToMerge is a power of 2.
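Merging groups into both GEMM M and GEMM N tiles the weight tensor into a NumGroupsToMerge x NumGroupsToMerge grid of sub-matrices, of which only the diagonal blocks hold real weights; the xor transform below redirects every off-diagonal block into the pad region, and the static_assert that follows enforces the power-of-two requirement. A small self-contained check of the property the comment relies on (illustrative, not CK code):

    #include <cassert>

    void check_xor_diagonal(int num_groups_to_merge) // must be a power of two
    {
        for(int m = 0; m < num_groups_to_merge; ++m)
            for(int n = 0; n < num_groups_to_merge; ++n)
            {
                const int t = m ^ n;             // coordinate after make_xor_transform
                assert(t < num_groups_to_merge); // stays in range, so no modulo is needed
                assert((t == 0) == (m == n));    // hits offset 0 exactly on the diagonal
            }
    }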
- static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 || - NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 || - NumBatchToMerge == 64); + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); const auto unmerged_padded_desc = transform_tensor_descriptor( padded_desc, - make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)), + make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), make_pass_through_transform(K), make_pass_through_transform(Z * Y * X), make_pass_through_transform(C)), @@ -215,8 +215,8 @@ struct TransformConvBwdWeightToGemmV2 // Merge To M, N return transform_tensor_descriptor( unmerged_padded_desc, - make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)), - make_merge_transform(make_tuple(Z * Y * X, NumBatchToMerge, C))), + make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)), + make_merge_transform(make_tuple(Z * Y * X, NumGroupsToMerge, C))), make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -262,8 +262,8 @@ struct TransformConvBwdWeightToGemmV2 const index_t InRightPadW = input_right_pads[1]; const index_t GemmKTotal = N * Ho * Wo; - const index_t GemmM = K * NumBatchToMerge; - const index_t GemmN = C * X * Y * NumBatchToMerge; + const index_t GemmM = K * NumGroupsToMerge; + const index_t GemmN = C * X * Y * NumGroupsToMerge; const auto PadGemmM = MPerBlock - GemmM % MPerBlock; const auto PadGemmN = NPerBlock - GemmN % NPerBlock; @@ -286,7 +286,7 @@ struct TransformConvBwdWeightToGemmV2 out_grid_desc, make_tuple( make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), make_tuple(Sequence<0>{}, Sequence<1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -302,7 +302,7 @@ struct TransformConvBwdWeightToGemmV2 in_grid_desc, make_tuple( make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))), make_tuple(Sequence<0>{}, Sequence<1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -324,7 +324,7 @@ struct TransformConvBwdWeightToGemmV2 out_grid_desc, make_tuple( make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), make_tuple(Sequence<0>{}, Sequence<1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -341,7 +341,7 @@ struct TransformConvBwdWeightToGemmV2 make_tuple(make_pass_through_transform(N), make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(C)), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), @@ -354,7 +354,7 @@ struct TransformConvBwdWeightToGemmV2 make_pass_through_transform(N), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - 
make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(C)), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), @@ -366,7 +366,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, NumBatchToMerge, C)), + make_tuple(make_merge_transform(make_tuple(Y, X, NumGroupsToMerge, C)), make_merge_transform(make_tuple(N, Ho, Wo))), make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); @@ -465,8 +465,8 @@ struct TransformConvBwdWeightToGemmV2 const index_t InRightPadW = input_right_pads[2]; const index_t GemmKTotal = N * Do * Ho * Wo; - const index_t GemmM = K * NumBatchToMerge; - const index_t GemmN = C * Z * X * Y * NumBatchToMerge; + const index_t GemmM = K * NumGroupsToMerge; + const index_t GemmN = C * Z * X * Y * NumGroupsToMerge; const auto PadGemmM = MPerBlock - GemmM % MPerBlock; const auto PadGemmN = NPerBlock - GemmN % NPerBlock; @@ -489,7 +489,7 @@ struct TransformConvBwdWeightToGemmV2 out_grid_desc, make_tuple( make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), make_tuple(Sequence<0>{}, Sequence<1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -505,7 +505,7 @@ struct TransformConvBwdWeightToGemmV2 in_grid_desc, make_tuple( make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))), make_tuple(Sequence<0>{}, Sequence<1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -527,7 +527,7 @@ struct TransformConvBwdWeightToGemmV2 out_grid_desc, make_tuple( make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), make_tuple(Sequence<0>{}, Sequence<1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -545,7 +545,7 @@ struct TransformConvBwdWeightToGemmV2 make_pad_transform(Di, InLeftPadD, InRightPadD), make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -567,7 +567,7 @@ struct TransformConvBwdWeightToGemmV2 make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -584,7 +584,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( in_n_z_do_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumBatchToMerge, C)), + make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumGroupsToMerge, C)), make_merge_transform(make_tuple(N, Do, Ho, Wo))), make_tuple(Sequence<1, 3, 5, 
7, 8>{}, Sequence<0, 2, 4, 6>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index 2a02d25341..8dd6573015 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -25,11 +25,17 @@ __host__ __device__ auto mult_accumulate_n(ForwardIterator first, Size count, T return init; } -template +template struct TransformConvFwdToGemm { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; static long_index_t calculate_element_space_size_impl(const std::array& lengths, @@ -117,13 +123,18 @@ struct TransformConvFwdToGemm const std::array& input_right_pads, const index_t N) { - const index_t C = a_g_n_c_wis_lengths[2]; + const index_t C = a_g_n_c_wis_lengths[I2]; - const index_t Wi = a_g_n_c_wis_lengths[3]; + const index_t Wi = a_g_n_c_wis_lengths[I3]; - const index_t Wo = c_g_n_k_wos_lengths[3]; + const index_t Wo = c_g_n_k_wos_lengths[I3]; - const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[I0]; + + const index_t GStride = a_g_n_c_wis_strides[I0]; + const index_t NStride = a_g_n_c_wis_strides[I1]; + const auto CStride = a_g_n_c_wis_strides[I2]; + const index_t WiStride = a_g_n_c_wis_strides[I3]; if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) @@ -132,41 +143,135 @@ struct TransformConvFwdToGemm N * ck::accumulate_n( c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - // This is different - const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; - const auto CStride = I1; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NHoWo, C), + make_tuple(WiStride, CStride)); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(NHoWo, NumGroupsToMerge, C), make_tuple(WiStride, GStride, CStride)); - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + const index_t ConvDilationW = conv_filter_dilations[0]; - return in_gemmm_gemmk_desc; + const index_t InLeftPadW = input_left_pads[0]; + + const index_t InRightPadW = input_right_pads[0]; + if constexpr(NumGroupsToMerge == 1) + { + + const auto in_n_wi_c_desc = + make_naive_tensor_descriptor(make_tuple(N, Wi), make_tuple(NStride, WiStride)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N), 
+ make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_pass_through_transform(Number<3>{})), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, NumGroupsToMerge), make_tuple(NStride, WiStride, GStride)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)), + make_pass_through_transform(Number<3>{})), + make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Pad0) { - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t WiStride = a_g_n_c_wis_strides[3]; - const auto CStride = I1; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); - const auto in_n_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - const auto in_n_wo_c_desc = transform_tensor_descriptor( - in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + return transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = + make_naive_tensor_descriptor(make_tuple(N, Wi, NumGroupsToMerge, C), + make_tuple(NStride, WiStride, GStride, CStride)); - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_n_wo_c_desc = 
transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - return in_gemmm_gemmk_desc; + return transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else { @@ -174,40 +279,67 @@ struct TransformConvFwdToGemm const index_t ConvDilationW = conv_filter_dilations[0]; const index_t InLeftPadW = input_left_pads[0]; const index_t InRightPadW = input_right_pads[0]; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t WiStride = a_g_n_c_wis_strides[3]; - const auto CStride = I1; + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - const auto in_n_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - const auto in_n_wip_c_desc = transform_tensor_descriptor( - in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_merge_transform(make_tuple(X, C))), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = + make_naive_tensor_descriptor(make_tuple(N, Wi, NumGroupsToMerge, C), + make_tuple(NStride, WiStride, GStride, CStride)); - const auto in_n_x_wo_c_desc = transform_tensor_descriptor( - in_n_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); 
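Every branch above follows the same im2col chain: a naive descriptor, padding of Wi, an embed of (X, Wo) through (dilation, stride), then a merge into the final GEMM view. The merged-groups variants simply thread one extra NumGroupsToMerge dimension (strided by GStride) through each step and fold it into GEMM M at the end, while Filter3x3 runs the identical chain with the filter extent spelled as Number<3>{} so the transforms see a compile-time constant. The resulting extents, as a sketch (standard convolution output-size formula; the helper names are hypothetical, not CK code):

    struct GemmShape { int m, k; };

    // 1D case: GEMM M merges (N, Wo[, NumGroupsToMerge]); GEMM K merges (X, C).
    GemmShape gemm_shape_1d(int N, int Wi, int X, int C, int pad_left, int pad_right,
                            int dilation, int stride, int num_groups_to_merge)
    {
        const int Wo = (Wi + pad_left + pad_right - dilation * (X - 1) - 1) / stride + 1;
        return {N * Wo * num_groups_to_merge, X * C};
    }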
- const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), - make_merge_transform(make_tuple(X, C))), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4>{})); - return in_gemmm_gemmk_desc; + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(X, C))), + make_tuple(Sequence<0, 2, 3>{}, Sequence<1, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } } @@ -242,51 +374,160 @@ struct TransformConvFwdToGemm const index_t ConvStrideH = conv_filter_strides[0]; const index_t ConvStrideW = conv_filter_strides[1]; + const index_t GStride = a_g_n_c_wis_strides[I0]; + const index_t NStride = a_g_n_c_wis_strides[I1]; + const index_t CStride = a_g_n_c_wis_strides[I2]; + const index_t HiStride = a_g_n_c_wis_strides[I3]; + const index_t WiStride = a_g_n_c_wis_strides[I4]; + if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { const index_t NHoWo = N * ck::accumulate_n( c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NHoWo, C), + make_tuple(WiStride, CStride)); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(NHoWo, NumGroupsToMerge, C), make_tuple(WiStride, GStride, CStride)); - // This is different - const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; - const auto CStride = I1; + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; - return in_gemmm_gemmk_desc; + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi), make_tuple(NStride, HiStride, WiStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + 
in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_hi_wi_groups_c_desc = + make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, NumGroupsToMerge), + make_tuple(NStride, HiStride, WiStride, GStride)); + + const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hip_wip_groups_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Pad0) { - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t HiStride = a_g_n_c_wis_strides[3]; - const index_t WiStride = a_g_n_c_wis_strides[4]; - const auto CStride = I1; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( - in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, 
Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + return transform_tensor_descriptor( + in_n_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, HiStride, WiStride, GStride, CStride)); - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - return in_gemmm_gemmk_desc; + return transform_tensor_descriptor( + in_n_ho_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else { @@ -302,42 +543,81 @@ struct TransformConvFwdToGemm const index_t InRightPadH = input_right_pads[0]; const index_t InRightPadW = input_right_pads[1]; - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t HiStride = a_g_n_c_wis_strides[3]; - const index_t WiStride = a_g_n_c_wis_strides[4]; - const auto CStride = I1; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( - in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + 
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( - in_n_hip_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + return transform_tensor_descriptor( + in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_merge_transform(make_tuple(Y, X, C))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_merge_transform(make_tuple(Y, X, C))), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, HiStride, WiStride, GStride, CStride)); - return in_gemmm_gemmk_desc; + const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hip_wip_groups_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5>{}, + Sequence<6>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(Y, X, C))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } } @@ -375,6 +655,13 @@ struct TransformConvFwdToGemm const index_t ConvStrideH = conv_filter_strides[1]; const index_t ConvStrideW = conv_filter_strides[2]; + const index_t GStride = a_g_n_c_wis_strides[I0]; + const index_t NStride = a_g_n_c_wis_strides[I1]; + const index_t CStride = a_g_n_c_wis_strides[I2]; + const index_t DiStride = a_g_n_c_wis_strides[I3]; + const index_t HiStride = a_g_n_c_wis_strides[I4]; + const index_t WiStride = a_g_n_c_wis_strides[I5]; + if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { @@ -382,49 +669,182 @@ struct TransformConvFwdToGemm N * ck::accumulate_n( c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, 
std::multiplies<>()); - // This is different - const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; - const auto CStride = I1; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), + make_tuple(WiStride, CStride)); + } + else + { + const auto in_gemmm_groups_gemmk_desc = + make_naive_tensor_descriptor(make_tuple(NDoHoWo, NumGroupsToMerge, C), + make_tuple(WiStride, GStride, CStride)); - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride)); + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; - return in_gemmm_gemmk_desc; + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi), make_tuple(NStride, DiStride, HiStride, WiStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Do), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Number<3>{}, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, NumGroupsToMerge), + make_tuple(NStride, DiStride, HiStride, WiStride, GStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + Sequence<0>{}, 
Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Do), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Number<3>{}, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Pad0) { - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t DiStride = a_g_n_c_wis_strides[3]; - const index_t HiStride = a_g_n_c_wis_strides[4]; - const index_t WiStride = a_g_n_c_wis_strides[5]; - const auto CStride = I1; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); - const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, C), - make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + return transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride)); - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_do_ho_wo_c_desc, - 
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); - return in_gemmm_gemmk_desc; + return transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else { @@ -444,53 +864,107 @@ struct TransformConvFwdToGemm const index_t InRightPadH = input_right_pads[1]; const index_t InRightPadW = input_right_pads[2]; - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t DiStride = a_g_n_c_wis_strides[3]; - const index_t HiStride = a_g_n_c_wis_strides[4]; - const index_t WiStride = a_g_n_c_wis_strides[5]; - const auto CStride = I1; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); - const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, C), - make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + 
Sequence<7>{})); - const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( - in_n_hip_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, - Sequence<1, 2>{}, - Sequence<3, 4>{}, - Sequence<5, 6>{}, - Sequence<7>{})); + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_merge_transform(make_tuple(Z, Y, X, C))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride)); - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_z_do_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_merge_transform(make_tuple(Z, Y, X, C))), - make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); - return in_gemmm_gemmk_desc; + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{}, + Sequence<8>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(Z, Y, X, C))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5, 8>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } } @@ -499,9 +973,8 @@ struct TransformConvFwdToGemm is_same_v || is_same_v, bool>::type = false> - static auto - MakeBDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */) + static auto MakeBDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) { const index_t K = b_g_k_c_xs_lengths[1]; const index_t C = b_g_k_c_xs_lengths[2]; @@ -509,10 +982,54 @@ struct 
TransformConvFwdToGemm
         const index_t YX = ck::accumulate_n(
             b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
 
-        const auto wei_gemmn_gemmk_desc =
-            make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
+        const index_t GStride = b_g_k_c_xs_strides[I0];
+        const index_t KStride = b_g_k_c_xs_strides[I1];
+        const index_t CStride = b_g_k_c_xs_strides[I2];
 
-        return wei_gemmn_gemmk_desc;
+        if constexpr(ConvForwardSpecialization ==
+                     device::ConvolutionForwardSpecialization::Filter3x3)
+        {
+            using FilterSizeNumType =
+                std::conditional_t<NDimSpatial == 1,
+                                   Number<3>,
+                                   std::conditional_t<NDimSpatial == 2, Number<9>, Number<27>>>;
+
+            if constexpr(NumGroupsToMerge == 1)
+            {
+                return make_naive_tensor_descriptor_packed(make_tuple(K, FilterSizeNumType{}));
+            }
+            else
+            {
+
+                const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
+                    make_tuple(K, NumGroupsToMerge, FilterSizeNumType{}),
+                    make_tuple(KStride, GStride, CStride));
+                return transform_tensor_descriptor(
+                    wei_gemmn_groups_gemmk_desc,
+                    make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)),
+                               make_pass_through_transform(FilterSizeNumType{})),
+                    make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));
+            }
+        }
+        else
+        {
+            if constexpr(NumGroupsToMerge == 1)
+            {
+                return make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
+            }
+            else
+            {
+                const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
+                    make_tuple(K, NumGroupsToMerge, YX * C), make_tuple(KStride, GStride, CStride));
+                return transform_tensor_descriptor(
+                    wei_gemmn_groups_gemmk_desc,
+                    make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)),
+                               make_pass_through_transform(YX * C)),
+                    make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));
+            }
+        }
     }
 
     template <
@@ -585,17 +1102,53 @@ struct TransformConvFwdToGemm
     {
         const index_t K = c_g_n_k_wos_lengths[2];
 
-        const auto KStride = I1;
+        const index_t KStride = I1;
         const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];
+        const index_t GStride = c_g_n_k_wos_strides[0];
 
         const index_t NHoWo =
             N * ck::accumulate_n(
                     c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
-
-        const auto out_gemmm_gemmn_desc =
-            make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));
-
-        return out_gemmm_gemmn_desc;
+        if constexpr(NumGroupsToMerge == 1)
+        {
+            return make_naive_tensor_descriptor(make_tuple(NHoWo, K),
+                                                make_tuple(WoStride, KStride));
+        }
+        else
+        {
+            const auto nhwo_groups_k_1_desc =
+                make_naive_tensor_descriptor(make_tuple(NHoWo, NumGroupsToMerge, K, 1),
+                                             make_tuple(WoStride, GStride, KStride, GStride));
+            // Pad 1 to NumGroupsToMerge
+            const auto padded_desc = transform_tensor_descriptor(
+                nhwo_groups_k_1_desc,
+                make_tuple(make_pass_through_transform(NHoWo),
+                           make_pass_through_transform(NumGroupsToMerge),
+                           make_pass_through_transform(K),
+                           make_pad_transform(1, 0, NumGroupsToMerge - 1)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+            // We only need the matrices on the diagonal. Xor returns 0 for equal
+            // values, so any matrix that is not on the diagonal is stored in the padding.
+            // To avoid a modulo after the xor we assume NumGroupsToMerge is a power of 2.
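+            // [Editorial sketch, not part of the original patch] The xor mapping can
+            // be checked in isolation: gm ^ gn == 0 exactly when gm == gn, so only the
+            // diagonal blocks address the single real element of the padded axis, and
+            // every off-diagonal block lands at a padding offset >= 1:
+            //
+            //     #include <cassert>
+            //     int main()
+            //     {
+            //         constexpr int G = 8; // stands in for NumGroupsToMerge
+            //         for(int gm = 0; gm < G; ++gm)
+            //             for(int gn = 0; gn < G; ++gn)
+            //                 assert(((gm ^ gn) == 0) == (gm == gn)); // 0 iff diagonal
+            //     }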
+ static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K, NumGroupsToMerge))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } // for output bias diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp index 77d3728430..41303d2e95 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp @@ -40,10 +40,10 @@ template using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std::tuple< // clang-format off - //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumBatch| - //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| - //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| 
BlockGemm| BlockGemm| NumGroups| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp new file mode 100644 index 0000000000..96baf6bb00 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
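+
+// [Editorial note, not part of the original patch] This new header defines the
+// merged-groups instance tuples: for each data type it lists
+// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instances with NumGroupsToMerge set
+// to 8, 16 and 32; the registration .cpp files later in this patch instantiate them
+// for both the Default and the Filter3x3 forward specializations.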
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd3x3 = ConvolutionForwardSpecialization::Filter3x3; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Instances with NumGroupsPerBatch > 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| 
Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Instances with NumGroupsPerBatch > 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | 
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Instances with NumGroupsPerBatch > 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 32> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index ec5bd785a3..0233d6d85c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -17,6 +17,7 @@ #endif #ifdef CK_USE_XDL #include "grouped_convolution_forward_xdl.inc" +#include "grouped_convolution_forward_xdl_merged_groups.inc" #include "grouped_convolution_forward_comp_xdl.inc" #include "grouped_convolution_forward_mem_inter_xdl.inc" #include "grouped_convolution_forward_mem_intra_xdl.inc" @@ -199,6 +200,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( op_ptrs); @@ -212,6 +215,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances( op_ptrs); @@ -227,6 +232,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( op_ptrs); @@ -284,6 +291,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( op_ptrs); @@ -338,6 +347,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(op_ptrs); 
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances( op_ptrs); @@ -353,6 +364,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc new file mode 100644 index 0000000000..fe09d3f6a7 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 4e002c7222..170625a6a0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -9,6 +9,11 @@ add_instance_library(device_grouped_conv2d_fwd_instance xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + # merged groups + # NHWGC, GKYXC, NHWGK + xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp + xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp #mem # NHWGC, GKYXC, NHWGK xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..6fa4bc6e46 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000..9fa56f48c7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp new file mode 100644 index 0000000000..e226dae975 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index e24dbcd2c1..5be6672723 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -9,6 +9,10 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp 
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..cf1fcec985 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..bea62892d3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp new file mode 100644 index 0000000000..de44725413 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp index 21fe7992ac..1bfc183135 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp @@ -104,6 +104,7 @@ TYPED_TEST(TestGroupedConvndFwd1d, Test1D) this->conv_params.push_back({1, 2, 32, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); this->conv_params.push_back({1, 1, 1, 1, 32, {3}, {32}, {1}, {1}, {1}, {1}}); this->conv_params.push_back({1, 1, 1, 64, 3, {3}, {32}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 96, 1, 1, 1, {3}, {512}, {1}, {1}, {1}, {1}}); this->template Run<1>(); } @@ -119,6 +120,8 @@ TYPED_TEST(TestGroupedConvndFwd2d, Test2D) this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 
1}, {1, 1}, {1, 1}});
     this->template Run<2>();
 }
@@ -137,6 +140,8 @@ TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
         {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
     this->conv_params.push_back(
         {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
+    this->conv_params.push_back(
+        {3, 96, 1, 1, 1, {3, 3, 3}, {4, 30, 160}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
     this->template Run<3>();
 }
@@ -144,6 +149,9 @@ TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases)
 {
     // Case larger than 2GB
     this->conv_params.push_back(
-        {2, 1, 64, 4, 192, {2, 2}, {224, 224}, {224, 224}, {0, 0}, {0, 0}, {0, 0}});
+        {2, 1, 64, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}});
+    // Large case with NumGroupsToMerge > 1
+    this->conv_params.push_back(
+        {2, 32, 64, 1, 1, {2, 2}, {672, 672}, {672, 672}, {1, 1}, {0, 0}, {0, 0}});
     this->template Run<2>();
 }
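
Editorial closing example (a minimal sketch, not part of the patch; the variable names and the choice of 32 are illustrative): the new depthwise-style test cases have K = C = 1, so without group merging the GEMM N dimension degenerates to 1 and leaves almost all of a GEMM tile idle. Applying the merge transforms from transform_conv_fwd_to_gemm.hpp above to the 2D case {2, 96, 1, 1, 1, {3, 3}, {120, 160}, ...} (stride/dilation 1, pad 1, hence Ho = 120, Wo = 160) gives:

    #include <cstdio>

    int main()
    {
        const int N = 1, K = 1, C = 1, Y = 3, X = 3, Ho = 120, Wo = 160;
        const int merges[] = {1, 32}; // 32 is one of the instantiated NumGroupsToMerge values
        for(const int g : merges)
        {
            const int gemm_m = N * Ho * Wo * g; // M merges N, Ho, Wo and the merged groups
            const int gemm_n = K * g;           // N merges K and the merged groups
            const int gemm_k = Y * X * C;       // GemmK (Y * X * C) is unchanged by merging
            std::printf("NumGroupsToMerge=%2d -> GemmM=%6d GemmN=%2d GemmK=%d\n",
                        g, gemm_m, gemm_n, gemm_k);
        }
        return 0;
    }

This prints GemmN = 1 for the unmerged case and GemmN = 32 once groups are merged, which is why the merged-groups instances are registered for these shapes.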