diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 6d04835b21..6d2988ba24 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -77,7 +77,8 @@ template + bool isMultiB, + bool CTranspose> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -171,17 +172,22 @@ __global__ void } else { - const long_index_t a_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); const long_index_t b_group_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - + CTranspose + ? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t a_group_offset = + CTranspose + ? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_n_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0; const long_index_t a_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); GridwiseGemm::template Run( p_as_grid + a_group_offset + a_n_offset, - p_bs_grid + b_group_offset, + p_bs_grid + b_group_offset + b_n_offset, p_ds_grid_grp, p_e_grid + e_group_offset + e_n_offset, p_shared, @@ -335,12 +341,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle static constexpr auto I4 = Number<4>{}; static constexpr auto I5 = Number<5>{}; + static constexpr bool isATensorColMajor = + (ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) && + (ABlockTransferSrcVectorDim == 1) && (NumGroupsToMerge == 1) && + (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool NeedTransposeKernel = + (isATensorColMajor == false) && (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool CTranspose = (NeedTransposeKernel == false) && (isMultiAB == false) && + (is_same_v || + is_same_v); + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + NumGroupsToMerge, + index_t, + CTranspose>; static constexpr index_t ClusterLengthNPerBlock = CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); @@ -361,9 +383,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_NGKHW(), + is_NGCHW_NGKHW() && NeedTransposeKernel, ctc::NHWGC, - std::conditional_t(), ctc::NDHWGC, ALay>>; + std::conditional_t() && NeedTransposeKernel, + ctc::NDHWGC, + ALay>>; const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(); @@ -379,9 +403,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_NGKHW(), + is_NGCHW_NGKHW() && NeedTransposeKernel, ctc::GKYXC, - std::conditional_t(), ctc::GKZYXC, BLay>>; + std::conditional_t() && NeedTransposeKernel, + ctc::GKZYXC, + BLay>>; const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(); @@ -397,17 +423,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_NGKHW(), + is_NGCHW_NGKHW() && NeedTransposeKernel, ctc::NHWGK, - std::conditional_t(), ctc::NDHWGK, ELay>>; + std::conditional_t() && NeedTransposeKernel, + ctc::NDHWGK, + ELay>>; const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); - - const auto out_gemmm_gemmn_desc = - matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); - - return out_gemmm_gemmn_desc; + if constexpr(CTranspose) + { + constexpr auto matrix_padder_trans = + MatrixPadder{NPerBlock, MPerBlock, KPerBlock}; + return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + } + else + { + return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + } } // Shape of Ds and E must be aligned. Strides can be different. @@ -471,11 +504,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ BComputeDataType, DoElementwiseBeforeCShuffle + +#define GridwiseGemmCTransposeTemplateParameters \ + GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \ + NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \ + MPerXDL, NXdlPerWave, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ + BComputeDataType, DoElementwiseBeforeCShuffle + // Use appropriate gridwise gemm using GridwiseGemm = std::conditional_t< isMultiA || isMultiB, GridwiseGemmMultipleABD_xdl_cshuffle, GridwiseGemmMultipleD_xdl_cshuffle>; + using GridwiseGemmCTranspose = std::conditional_t< + CTranspose, + GridwiseGemmMultipleD_xdl_cshuffle, + GridwiseGemm>; // If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers. using APointers = @@ -497,15 +551,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; using NGCHWTransposeDescType = @@ -612,16 +667,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, - a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( - a_g_n_c_wis_lengths, a_g_n_c_wis_strides)}, + a_g_n_c_wis_strides_{NeedTransposeKernel + ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides) + : a_g_n_c_wis_strides}, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, - b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( - b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, + b_g_k_c_xs_strides_{NeedTransposeKernel + ? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides) + : b_g_k_c_xs_strides}, ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, - e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( - e_g_n_k_wos_lengths, e_g_n_k_wos_strides)}, + e_g_n_k_wos_strides_{NeedTransposeKernel + ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides) + : e_g_n_k_wos_strides}, conv_filter_strides_{conv_filter_strides}, conv_filter_dilations_{conv_filter_dilations}, input_left_pads_{input_left_pads}, @@ -651,7 +712,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + block_2_etile_map_{ + GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, compute_ptr_offset_of_groups_{}, compute_ptr_offset_of_n_{}, a_element_op_{a_element_op}, @@ -783,24 +845,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } else { - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) + bool valid = false; + if constexpr(CTranspose) { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); + valid = GridwiseGemmCTranspose::CheckValidity(b_grid_desc_n_k_, + a_grid_desc_m_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_); + } + else + { + valid = GridwiseGemmCTranspose::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_); + } + if(valid) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose:: + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_); } } - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { // Use not modified base strides a_in_transpose_desc_ = @@ -835,8 +907,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle std::size_t GetWorkspaceATensorSizeBytes() const { - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { const long_index_t a_acum = ck::accumulate_n( a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); @@ -851,8 +922,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle std::size_t GetWorkspaceBTensorSizeBytes() const { - if constexpr(is_NGCHW_GKCYX_NGKHW() || - is_NGCDHW_GKCZYX_NGKDHW()) + if constexpr(NeedTransposeKernel) { const long_index_t b_acum = ck::accumulate_n( b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); @@ -867,8 +937,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle std::size_t GetWorkspaceETensorSizeBytes() const { - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { const long_index_t e_accum = ck::accumulate_n( e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); @@ -1007,7 +1076,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ComputePtrOffsetOfStridedBatch, has_main_loop, isMultiA, - isMultiB>; + isMultiB, + CTranspose>; return launch_and_time_kernel( stream_config, @@ -1035,68 +1105,118 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle const ADataType* p_a_grid = arg.p_as_grid_.At(I0); const BDataType* p_b_grid = arg.p_bs_grid_.At(I0); EDataType* p_e_grid = arg.p_e_grid_; - - if constexpr(is_NGCHW_GKCYX_NGKHW() || - is_NGCDHW_GKCZYX_NGKDHW()) + if constexpr(NeedTransposeKernel) { - p_a_grid = type_convert(arg.p_workspace_); - p_b_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); - p_e_grid = type_convert(arg.p_workspace_) + - (arg.GetWorkspaceATensorSizeBytes() + - arg.GetWorkspaceBTensorSizeBytes()) / - sizeof(EDataType); - } - else if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) - { - p_a_grid = type_convert(arg.p_workspace_); - p_e_grid = type_convert(arg.p_workspace_) + - (arg.GetWorkspaceATensorSizeBytes() + - arg.GetWorkspaceBTensorSizeBytes()) / - sizeof(EDataType); + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + p_e_grid = type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } + else if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + p_e_grid = type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } } - const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle< - GridwiseGemm, - const ADataType*, - const BDataType*, - typename GridwiseGemm::DsGridPointer, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - Block2ETileMap, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - has_main_loop, - isMultiA, - isMultiB>; + if constexpr(CTranspose) + { + const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemmCTranspose, + const BDataType*, + const ADataType*, + typename GridwiseGemm::DsGridPointer, + EDataType, + BElementwiseOperation, + AElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + has_main_loop, + isMultiA, + isMultiB, + CTranspose>; - return launch_and_time_kernel( - stream_config, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - arg.p_ds_grid_, - p_e_grid, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_etile_map_, - arg.compute_ptr_offset_of_groups_, - arg.compute_ptr_offset_of_n_); + return launch_and_time_kernel( + stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + p_b_grid, + p_a_grid, + arg.p_ds_grid_, + p_e_grid, + arg.b_element_op_, + arg.a_element_op_, + arg.cde_element_op_, + arg.b_grid_desc_bk0_n_bk1_, + arg.a_grid_desc_ak0_m_ak1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + } + else + { + const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + const ADataType*, + const BDataType*, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + has_main_loop, + isMultiA, + isMultiB, + CTranspose>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + arg.p_ds_grid_, + p_e_grid, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + } } }; @@ -1114,8 +1234,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { float avg_time = 0.f; - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { const index_t a_grid_size = arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( @@ -1166,8 +1285,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle avg_time += RunGemm(arg, stream_config); - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { const index_t grid_size = arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( @@ -1215,9 +1333,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; - const index_t G = arg.b_g_k_c_xs_lengths_[I0]; - const index_t K = arg.b_g_k_c_xs_lengths_[I1]; - const index_t C = arg.b_g_k_c_xs_lengths_[I2]; + const index_t G = arg.b_g_k_c_xs_lengths_[I0]; + const index_t K = arg.b_g_k_c_xs_lengths_[I1]; + const index_t C = arg.b_g_k_c_xs_lengths_[I2]; + const index_t input_spatial_acum = ck::accumulate_n( + arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); // check device if(get_device_name() == "gfx908") @@ -1310,7 +1430,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v) + NeedTransposeKernel) { // Check access per C if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) @@ -1326,6 +1446,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } } } + else if constexpr(is_same_v || is_same_v) + { + static_assert(NeedTransposeKernel == false); + static_assert(NumGroupsToMerge == 1); + + if constexpr(ABlockTransferSrcScalarPerVector != 1) + { + if(ABlockTransferSrcVectorDim != 1) + { + return false; + } + if(input_spatial_acum % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + } else { return false; @@ -1350,7 +1487,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { return false; } - // check vector access of Ds bool valid = true; @@ -1396,8 +1532,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } }); - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0) { @@ -1409,8 +1544,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return false; } - const index_t input_spatial_acum = ck::accumulate_n( - arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); const index_t output_spatial_acum = ck::accumulate_n( arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); @@ -1457,9 +1590,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v || is_same_v || is_same_v || is_same_v) { - if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + if(CTranspose == false) { - return false; + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + const index_t output_spatial_acum = ck::accumulate_n( + arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + + if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } } } else @@ -1483,11 +1629,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } else { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); + if constexpr(CTranspose) + { + return GridwiseGemmCTranspose::CheckValidity(arg.b_grid_desc_n_k_, + arg.a_grid_desc_m_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + return GridwiseGemmCTranspose::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } } } diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index c291f3994c..92b48c44b3 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -19,7 +19,8 @@ template + typename IndexType = index_t, + bool CTranspose = false> struct TransformConvFwdToGemm { private: @@ -1253,6 +1254,83 @@ struct TransformConvFwdToGemm } } + template , + bool>::type = false> + __host__ __device__ auto MakeADescriptor_M_K() const + { + static_assert(NumGroupsToMerge == 1); + static_assert(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0); + + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template , + bool>::type = false> + __host__ __device__ auto MakeADescriptor_M_K() const + { + static_assert(NumGroupsToMerge == 1); + static_assert(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0); + + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Ho_ * Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_ * Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template , + bool>::type = false> + __host__ __device__ auto MakeADescriptor_M_K() const + { + static_assert(NumGroupsToMerge == 1); + static_assert(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0); + + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Do_ * Ho_ * Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template || + is_same_v || + is_same_v, + bool>::type = false> + __host__ __device__ auto MakeBDescriptor_N_K() const + { + static_assert(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 || + ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0); + static_assert(NumGroupsToMerge == 1); + return make_naive_tensor_descriptor_packed(make_tuple(K_, C_)); + } + template || is_same_v || @@ -1338,8 +1416,16 @@ struct TransformConvFwdToGemm bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { - return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_), - make_tuple(I0, KStrideTensorC_)); + if constexpr(CTranspose) + { + return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_), + make_tuple(KStrideTensorC_, I0)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_), + make_tuple(I0, KStrideTensorC_)); + } } template < @@ -1350,8 +1436,16 @@ struct TransformConvFwdToGemm bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { - return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), - make_tuple(I0, KStrideTensorC_)); + if constexpr(CTranspose) + { + return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_), + make_tuple(KStrideTensorC_, I0)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), + make_tuple(I0, KStrideTensorC_)); + } } template < @@ -1362,12 +1456,21 @@ struct TransformConvFwdToGemm bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { - return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), - make_tuple(I0, KStrideTensorC_)); + if constexpr(CTranspose) + { + return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_), + make_tuple(KStrideTensorC_, I0)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), + make_tuple(I0, KStrideTensorC_)); + } } template || is_same_v || @@ -1375,6 +1478,7 @@ struct TransformConvFwdToGemm bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { + static_assert(CTranspose == false); const IndexType NDoHoWo = N_ * Wo_; if constexpr(NumGroupsToMerge == 1) { @@ -1429,6 +1533,7 @@ struct TransformConvFwdToGemm bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { + static_assert(CTranspose == false); const IndexType NDoHoWo = N_ * Ho_ * Wo_; if constexpr(NumGroupsToMerge == 1) { @@ -1486,7 +1591,7 @@ struct TransformConvFwdToGemm bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { - + static_assert(CTranspose == false); const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_; if constexpr(NumGroupsToMerge == 1) { @@ -1536,6 +1641,101 @@ struct TransformConvFwdToGemm } } + template || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + static_assert(NumGroupsToMerge == 1); + auto n_k_wo_desc = make_naive_tensor_descriptor( + make_tuple(N_, K_, Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1)); + if constexpr(CTranspose) + { + return transform_tensor_descriptor( + n_k_wo_desc, + make_tuple(make_pass_through_transform(K_), + make_merge_transform(make_tuple(N_, Wo_))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor(n_k_wo_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + template || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + static_assert(NumGroupsToMerge == 1); + auto n_k_howo_desc = make_naive_tensor_descriptor( + make_tuple(N_, K_, Ho_ * Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1)); + if constexpr(CTranspose) + { + return transform_tensor_descriptor( + n_k_howo_desc, + make_tuple(make_pass_through_transform(K_), + make_merge_transform(make_tuple(N_, Ho_ * Wo_))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + n_k_howo_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_ * Wo_)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + template || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + static_assert(NumGroupsToMerge == 1); + auto n_k_dohowo_desc = make_naive_tensor_descriptor( + make_tuple(N_, K_, Do_ * Ho_ * Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1)); + + if constexpr(CTranspose) + { + return transform_tensor_descriptor( + n_k_dohowo_desc, + make_tuple(make_pass_through_transform(K_), + make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + n_k_dohowo_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } IndexType N_; IndexType Di_, Hi_, Wi_; IndexType Do_, Ho_, Wo_; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index d6b695360b..c641019b70 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -179,6 +179,38 @@ using device_grouped_conv_fwd_xdl_f16_instances = std::tuple< // clang-format on >; +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_f16_nchw_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>, + // 32x32 instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + // 16x16 instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4> + // clang-format on + >; + template , Shards, ShardIndex>{}); + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); } } // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp index 78d1747548..10267573da 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp @@ -31,6 +31,14 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances( Empty_Tuple, NGKHW, ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_nchw_instances<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp index 0ddf5bfa48..9795b6a096 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp @@ -47,6 +47,15 @@ void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instances( Empty_Tuple, NGKDHW, ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_nchw_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); } } // namespace instance