diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp index 6af8ac6488..1823d4fc0a 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp @@ -92,7 +92,7 @@ inline bool parse_cmd_args(int argc, const ck::index_t num_dim_spatial = std::stoi(argv[4]); conv_params = ck::utils::conv::parse_conv_param( - num_dim_spatial, threshold_to_catch_partial_args, argv); + num_dim_spatial, threshold_to_catch_partial_args + 1, argv); } else { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 89a304fda4..db2426518a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -74,7 +74,10 @@ template + InMemoryDataOperationEnum OutElementOp, + bool HasMainKBlockLoopInAllGemm, + bool NoMainKBlockLoopInAllGemm, + bool CTranspose> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -101,16 +104,21 @@ __global__ void const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch); const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); const long_index_t e_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const long_index_t a_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t b_n_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0; + const long_index_t e_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); @@ -141,11 +149,11 @@ __global__ void group_id = index_t((left + right) / 2); } - if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm) { - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid + a_batch_offset + a_n_offset, - p_b_grid + b_batch_offset, + p_b_grid + b_batch_offset + b_n_offset, p_ds_grid_grp, p_e_grid + e_batch_offset + e_n_offset, p_shared, @@ -162,22 +170,44 @@ __global__ void } else { - GridwiseGemm::template Run( - p_a_grid + a_batch_offset + a_n_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset + e_n_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, - gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, - gemm_kernel_args[group_id].block_2_ctile_map_, - KBatch, - k_idx); + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + p_a_grid + a_batch_offset + a_n_offset, + p_b_grid + b_batch_offset + b_n_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].block_2_ctile_map_, + KBatch, + k_idx); + } + else + { + GridwiseGemm::template Run( + p_a_grid + a_batch_offset + a_n_offset, + p_b_grid + b_batch_offset + b_n_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].block_2_ctile_map_, + KBatch, + k_idx); + } } #else ignore = p_a_grid; @@ -278,7 +308,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // implementation we can avoid copy data to workspace before kernel launch since number of // groups is runtime parameter. If number of groups is larger than MaxGroupedGemmGroupsNum then // we run this kernel in the loop. - static constexpr index_t MaxGroupedGemmGroupsNum = 32; + static constexpr index_t MaxGroupedGemmGroupsNum = + ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0 + ? 1 + : 32; using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; @@ -296,24 +330,40 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - using ALayoutAfterTranspose = - std::conditional_t(), - tensor_layout::convolution::NHWGK, - std::conditional_t(), - tensor_layout::convolution::NDHWGK, - ALayout>>; - using BLayoutAfterTranspose = - std::conditional_t(), - tensor_layout::convolution::GKYXC, - std::conditional_t(), - tensor_layout::convolution::GKZYXC, - BLayout>>; - using ELayoutAfterTranspose = - std::conditional_t(), - tensor_layout::convolution::NHWGC, - std::conditional_t(), - tensor_layout::convolution::NDHWGC, - ELayout>>; + static constexpr bool isATensorColMajor = + (ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) && + (ABlockTransferSrcVectorDim == 1) && + (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool NeedTransposeKernel = + (isATensorColMajor == false) && (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool CTranspose = + (NeedTransposeKernel == false) && (is_same_v || + is_same_v); + + using ALayoutAfterTranspose = std::conditional_t< + is_NGCHW_NGKHW() && NeedTransposeKernel, + tensor_layout::convolution::NHWGK, + std::conditional_t() && NeedTransposeKernel, + tensor_layout::convolution::NDHWGK, + ALayout>>; + using BLayoutAfterTranspose = std::conditional_t< + is_NGCHW_GKCYX_NGKHW() && NeedTransposeKernel, + tensor_layout::convolution::GKYXC, + std::conditional_t() && + NeedTransposeKernel, + tensor_layout::convolution::GKZYXC, + BLayout>>; + using ELayoutAfterTranspose = std::conditional_t< + is_NGCHW_NGKHW() && NeedTransposeKernel, + tensor_layout::convolution::NHWGC, + std::conditional_t() && NeedTransposeKernel, + tensor_layout::convolution::NDHWGC, + ELayout>>; using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1; + EDataType, + 1, + index_t, + CTranspose>; static auto GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform) @@ -357,15 +410,25 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 DLayout, true, /*SplitConvN*/ ABDataType, - DDataType>; + DDataType, + 1, /*index_t NumGroupsToMerge = 1,*/ + index_t, /* typename IndexType = */ + CTranspose>; return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N(); }, Number{}); const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N(); - - return make_tuple( - a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); + if constexpr(CTranspose) + { + return make_tuple( + b_grid_desc_bk0_n_bk1, a_grid_desc_ak0_m_ak1, ds_grid_desc_m_n, e_grid_desc_m_n); + } + else + { + return make_tuple( + a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); + } } // GridwiseGemm @@ -383,13 +446,34 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType + +#define GridwiseGemmCTransposeTemplateParameters \ + ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ + BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ + NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave, MXdlPerWave, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle; + using GridwiseGemmCTranspose = std::conditional_t< + CTranspose, + GridwiseGemmMultipleD_xdl_cshuffle, + GridwiseGemm>; template static auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n) { - return GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + return GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n); } template @@ -419,13 +503,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{})); using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); // block-to-e-tile map - using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + using Block2ETileMap = + decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})); using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; @@ -630,14 +715,17 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 sizeof(EDataType); std::array a_g_n_k_wos_strides_transposed = - conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, - a_g_n_k_wos_strides); + NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides) + : a_g_n_k_wos_strides; std::array b_g_k_c_xs_strides_transposed = - conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides); + NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides) + : b_g_k_c_xs_strides; std::array e_g_n_c_wis_strides_transposed = - conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(e_g_n_c_wis_lengths, - e_g_n_c_wis_strides); + NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + e_g_n_c_wis_lengths, e_g_n_c_wis_strides) + : e_g_n_c_wis_strides; // populate Ds pointer static_for<0, NumDTensor, 1>{}([&](auto i) { @@ -737,12 +825,27 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 conv_N_per_block_ = conv_to_gemm_transform_.N_; - const auto a_grid_desc_ak0_m_ak1 = - conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); - - const auto b_grid_desc_bk0_n_bk1 = - conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + const auto a_grid_desc_ak0_m_ak1 = [&]() { + if constexpr(CTranspose) + { + return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + } + else + { + return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); + } + }(); + const auto b_grid_desc_bk0_n_bk1 = [&]() { + if constexpr(CTranspose) + { + return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); + } + else + { + return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + } + }(); DsGridDesc_M_N ds_grid_desc_m_n; // populate Ds desc @@ -764,7 +867,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 DLayout, true, /*SplitConvN*/ ABDataType, - DDataType>; + DDataType, + 1, + index_t, + CTranspose>; ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{ a_g_n_k_wos_lengths, a_g_n_k_wos_strides_transposed, @@ -810,14 +916,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const auto GemmK = a_grid_desc_m_k.GetLength(I1); const bool HasMainKBlockLoop = - GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, k_batch_); + GridwiseGemmCTranspose::CalculateHasMainKBlockLoop(GemmK, k_batch_); gemm_kernel_args_[gemms_count_ / MaxGroupedGemmGroupsNum][gemms_count_ % MaxGroupedGemmGroupsNum] = GemmArgs{a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, - GridwiseGemm:: + GridwiseGemmCTranspose:: MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n), MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -851,8 +957,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_; - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { // Use not modified base strides a_in_transpose_desc_ = @@ -892,8 +997,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 std::size_t GetWorkspaceATensorSizeBytes() const { - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { const long_index_t a_acum = ck::accumulate_n( a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); @@ -908,8 +1012,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 std::size_t GetWorkspaceBTensorSizeBytes() const { - if constexpr(is_NGCHW_GKCYX_NGKHW() || - is_NGCDHW_GKCZYX_NGKDHW()) + if constexpr(NeedTransposeKernel) { const long_index_t b_acum = ck::accumulate_n( b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); @@ -924,8 +1027,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 std::size_t GetWorkspaceETensorSizeBytes() const { - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { const long_index_t e_accum = ck::accumulate_n( e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); @@ -1030,24 +1132,25 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const ADataType* p_a_grid = arg.p_a_grid_; const BDataType* p_b_grid = arg.p_b_grid_; EDataType* p_e_grid = arg.p_e_grid_; - - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { - p_a_grid = type_convert(arg.p_workspace_); - p_e_grid = - type_convert(arg.p_workspace_) + - (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / - sizeof(EDataType); - } + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } - if constexpr(is_NGCHW_GKCYX_NGKHW() || - is_NGCDHW_GKCZYX_NGKDHW()) - { - p_b_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + } } - for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size(); gemm_set_id++) { @@ -1067,42 +1170,111 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } }; - auto launch_kernel = [&]() { - const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - typename GridwiseGemm::DsGridPointer, - EDataType, - MaxGroupedGemmGroupsNum, - GemmArgs, - AElementwiseOp, - BElementwiseOp, - CDEElementwiseOp, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - ElementOp>; + bool has_loop_in_all_gemm = true; + bool no_loop_in_all_gemm = true; + for(auto i = 0; i < gemms_count_for_set; i++) + { + has_loop_in_all_gemm &= gemm_kernel_args[i].HasMainKBlockLoop_; + no_loop_in_all_gemm &= !gemm_kernel_args[i].HasMainKBlockLoop_; + } - return launch_and_time_kernel_with_preprocess(stream_config, - clear_workspace, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - arg.p_ds_grid_, - p_e_grid, - gemm_kernel_args, - gemms_count_for_set, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.compute_ptr_offset_of_batch_, - arg.compute_ptr_offset_of_n_, - arg.k_batch_); + auto launch_kernel = [&](auto has_main_k_block_loop, auto no_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool no_main_loop = no_main_k_block_loop.value; + if constexpr(CTranspose) + { + const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle< + GridwiseGemmCTranspose, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + MaxGroupedGemmGroupsNum, + GemmArgs, + BElementwiseOp, + AElementwiseOp, + CDEElementwiseOp, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + ElementOp, + has_main_loop, + no_main_loop, + CTranspose>; + + return launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + p_b_grid, + p_a_grid, + arg.p_ds_grid_, + p_e_grid, + gemm_kernel_args, + gemms_count_for_set, + arg.b_element_op_, + arg.a_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.compute_ptr_offset_of_n_, + arg.k_batch_); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + MaxGroupedGemmGroupsNum, + GemmArgs, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + ElementOp, + has_main_loop, + no_main_loop, + CTranspose>; + + return launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + arg.p_ds_grid_, + p_e_grid, + gemm_kernel_args, + gemms_count_for_set, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.compute_ptr_offset_of_n_, + arg.k_batch_); + } }; - - ave_time += launch_kernel(); + if(has_loop_in_all_gemm) + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(no_loop_in_all_gemm) + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } } return ave_time; @@ -1116,9 +1288,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { arg.Print(); } + // Transpose from NGKHW to NHWGK - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { EDataType* p_e_in_grid = type_convert(arg.p_workspace_) + @@ -1208,8 +1380,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // Transpose from NHWGC to NGCHW - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { const index_t grid_size = arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( @@ -1284,10 +1455,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } } - const index_t ConvG = arg.b_g_k_c_xs_lengths_[0]; - const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; - const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; - + const index_t ConvG = arg.b_g_k_c_xs_lengths_[0]; + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; + const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + const index_t output_spatial_acum = ck::accumulate_n( + arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + const index_t input_spatial_acum = ck::accumulate_n( + arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); // Specifialization if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) @@ -1307,15 +1481,30 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 if constexpr(is_same_v || is_same_v || is_same_v || - is_same_v || - is_same_v || - is_same_v) + is_same_v || NeedTransposeKernel) { if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) { return false; } } + else if(is_same_v || + is_same_v) + { + static_assert(NeedTransposeKernel == false); + + if constexpr(ABlockTransferSrcScalarPerVector != 1) + { + if(ABlockTransferSrcVectorDim != 1) + { + return false; + } + if(output_spatial_acum % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + } else { return false; @@ -1351,10 +1540,20 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 is_same_v || is_same_v) { - // vector load D matrix from global memory - if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + if(CTranspose == false) { - ds_valid = false; + // vector load D matrix from global memory + if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + ds_valid = false; + } + } + else + { + if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + ds_valid = false; + } } } else @@ -1376,10 +1575,20 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 is_same_v || is_same_v) { - // vector store C matrix into global memory - if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + if(CTranspose == false) { - return false; + // vector store C matrix into global memory + if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } } } else @@ -1390,7 +1599,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // Gridwise GEMM size for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++) { - if(!GridwiseGemm::CheckValidity( + if(!GridwiseGemmCTranspose::CheckValidity( arg.a_grid_desc_m_k_container_[i], arg.b_grid_desc_n_k_container_[i], arg.ds_grid_desc_m_n_container_[i], @@ -1403,8 +1612,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } } - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { if((ConvG * ConvC) % CDEBlockTransferScalarPerVector_NPerBlock != 0) { diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp index a191c75099..977c622f06 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -30,7 +30,8 @@ template < typename ADataType = float, typename CDataType = float, index_t NumGroupsToMerge = 1, - typename IndexType = index_t> + typename IndexType = index_t, + bool CTranspose = false> struct TransformConvBwdDataToGemm_v1 { private: @@ -555,6 +556,41 @@ struct TransformConvBwdDataToGemm_v1 return make_naive_tensor_descriptor_packed(make_tuple(N_, Do_, Ho_, Wo_, K_)); } } + else if constexpr(is_same_v) + { + // assume packed + static_assert(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0); + + const auto out_gemm_raw_grid_desc = make_naive_tensor_descriptor( + make_tuple(N_, Ho_ * Wo_, K_), make_tuple(NStrideTensorA_, I1, KStrideTensorA_)); + + return transform_tensor_descriptor( + out_gemm_raw_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_ * Wo_)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(is_same_v) + { + // assume packed + static_assert(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0); + + const auto out_gemm_raw_grid_desc = + make_naive_tensor_descriptor(make_tuple(N_, Do_ * Ho_ * Wo_, K_), + make_tuple(NStrideTensorA_, I1, KStrideTensorA_)); + + return transform_tensor_descriptor( + out_gemm_raw_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } else { throw std::runtime_error("wrong! unsupported layout: " + ALayout::name()); @@ -608,7 +644,9 @@ struct TransformConvBwdDataToGemm_v1 (is_same_v || is_same_v || is_same_v || - is_same_v), + is_same_v || + is_same_v || + is_same_v), bool>::type = false> __host__ __device__ auto MakeADescriptor_AK0_M_AK1() const { @@ -848,16 +886,16 @@ struct TransformConvBwdDataToGemm_v1 } } - template || - is_same_v), - bool>::type = false> + template < + typename BLayout_ = BLayout, + typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) && + (is_same_v || + is_same_v || + is_same_v || + is_same_v), + bool>::type = false> __host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const { - // assume packed - // k_y_x_c for 2d or k_z_y_x_c for 3d - const auto wei_grid_desc = MakeWeiGridDesc(); if constexpr(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: @@ -886,6 +924,12 @@ struct TransformConvBwdDataToGemm_v1 } else { + // assume packed + // k_y_x_c for 2d or k_z_y_x_c for 3d + static_assert(is_same_v || + is_same_v); + const auto wei_grid_desc = MakeWeiGridDesc(); + // GemmK is different for each GEMM const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_); const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_); @@ -1059,6 +1103,7 @@ struct TransformConvBwdDataToGemm_v1 bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { + static_assert(CTranspose == false); // assume strided // n_hi_wi_c for 2d n_di_hi_wi_c for 3d const auto in_grid_desc = MakeInGridDesc(); @@ -1314,6 +1359,48 @@ struct TransformConvBwdDataToGemm_v1 } } + template || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + const auto in_grid_desc = make_naive_tensor_descriptor( + make_tuple(N_, C_, Di_ * Hi_ * Wi_), make_tuple(NStrideTensorC_, CStrideTensorC_, I1)); + + static_assert(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0); + + if constexpr(CTranspose) + { + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(C_), + make_merge_transform(make_tuple(N_, Di_ * Hi_ * Wi_))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmNPerBlock, GemmMPerBlock), + Sequence{}); + } + else + { + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, Di_ * Hi_ * Wi_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + } + } // for input bias template ; +template +using device_grouped_conv_bwd_data_xdl_f16_nchw_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 8, 1, 32>, 2> + // clang-format on + >; // bf16_bf16_f32_bf16 template {}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_nchw_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp index bada2507c2..b1043260ea 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp @@ -32,6 +32,14 @@ void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances( Empty_Tuple, NGCDHW, ConvBwdDataDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_nchw_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance