From 6ad73cd0cdb9ee082df3ddaf920b91a0695f167e Mon Sep 17 00:00:00 2001 From: kiefer Date: Sun, 24 Aug 2025 12:44:01 +0000 Subject: [PATCH] Add CTranspose optimization for NCHW cases just like in xdl cshuffle non-v3 device implementation. --- ...conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp | 779 ++++++++---------- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 5 +- 2 files changed, 364 insertions(+), 420 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp index 0e484e9a1c..afc8e360c6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp @@ -339,22 +339,30 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 { using DeviceOp = DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3; - static constexpr bool isMultiA = is_detected::value; - static constexpr bool isMultiB = is_detected::value; - static constexpr bool isMultiD = DsDataType::Size() > 0; + static constexpr index_t NumGroupsToMerge = 1; // TODO: Implement merge groups. + + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + static constexpr bool isMultiAB = isMultiA || isMultiB; + static constexpr bool isMultiD = DsDataType::Size() > 0; // TODO: This will never be true pretty much. static constexpr bool isMultiABD = isMultiA && isMultiB && isMultiD; + // NGCHW is not supported for multiAB. + static_assert(!(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) || + !(isMultiA || isMultiB)); + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + // TODO: This parameter is no longer supported by Gridwise! // static constexpr bool DoElementwiseBeforeCShuffle = // !isMultiD && is_same_v && // !is_same_v; - static constexpr index_t NumATensor = GetNumABTensors(); - static constexpr index_t NumBTensor = GetNumABTensors(); - static constexpr index_t NumDTensor = DsDataType::Size(); - static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -362,6 +370,20 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 static constexpr auto I4 = Number<4>{}; static constexpr auto I5 = Number<5>{}; + static constexpr bool isATensorColMajor = + (ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) && + (ABlockTransferSrcVectorDim == 1) && (NumGroupsToMerge == 1) && + (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool NeedTransposeKernel = + (isATensorColMajor == false) && (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool CTranspose = (NeedTransposeKernel == false) && (isMultiAB == false) && + (is_same_v || + is_same_v); + // Generate vector size for C & Ds using CDEBlockTransferScalarPerVectors = typename uniform_sequence_gen; + EDataType, + NumGroupsToMerge, + index_t, + CTranspose>; using ComputePtrOffset = ComputePtrOffsetOfStridedBatch; - // TODO: Original xdl non-v3 chuffle had an isATensorColMajor parameter that had some very - // specific conditions and some interplay with the decision to use a transpose kernel. - // We need to duplicate this logic for proper nchw instance support. - - // TODO: Original xdl non-v3 chuffle had a CTranspose parameter that had some very - // specific conditions and decided whether to use CTranspose in the ConvToGemm transformers. - // We need to duplicate this logic for proper nchw instance support. - static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; @@ -404,11 +421,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_NGKHW(), // TODO: Removed weight layout check! + is_NGCHW_NGKHW() && NeedTransposeKernel, ctc::NHWGC, - std::conditional_t(), // TODO: Removed - // weight layout - // check! + std::conditional_t() && NeedTransposeKernel, ctc::NDHWGC, ALay>>; @@ -436,11 +451,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_NGKHW(), // TODO: Removed weight layout check! + is_NGCHW_NGKHW() && NeedTransposeKernel, ctc::GKYXC, - std::conditional_t(), // TODO: Removed - // weight layout - // check! + std::conditional_t() && NeedTransposeKernel, ctc::GKZYXC, BLay>>; @@ -468,21 +481,25 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_NGKHW(), // TODO: Removed weight layout check! + is_NGCHW_NGKHW() && NeedTransposeKernel, ctc::NHWGK, - std::conditional_t(), // TODO: Removed - // weight layout - // check! + std::conditional_t() && NeedTransposeKernel, ctc::NDHWGK, ELay>>; const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); - const auto out_gemmm_gemmn_desc = - matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); - - return out_gemmm_gemmn_desc; + if constexpr(CTranspose) + { + constexpr auto matrix_padder_trans = + MatrixPadder{NPerBlock, MPerBlock, KPerBlock}; + return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + } + else + { + return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + } } // Shape of Ds and E must be aligned. Strides can be different. @@ -553,6 +570,78 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 // TODO: Previously available template param DoElementwiseBeforeCShuffle! + // In case of CTranspose we swap the following template parameters: + // DataType, ElementWiseOp, PerBlock, K1, PerWmma, Repeat, All block transfer params. + using GridwiseGemmSwappedParams = GridwiseGemm_wmma_cshuffle_v3< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + + DsLayout, + tensor_layout::gemm::RowMajor, + + Tuple, + Tuple, + + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + + BElementwiseOperation, + AElementwiseOperation, + + CDEElementwiseOperation, + GemmSpec, + BlockSize, + + NPerBlock, + MPerBlock, + + KPerBlock, + + BK1, + AK1, + + NPerWmma, + MPerWmma, + + NRepeat, + MRepeat, + + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, // BThreadTransferSrcResetCoordinateAfterRun + BBlockLdsExtraN, + + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, // AThreadTransferSrcResetCoordinateAfterRun + ABlockLdsExtraM, + + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + + AComputeDataType, // TODO: swap these? + BComputeDataType, + + false, // PermuteA + false>; // PermuteB + + using GridwiseGemmCTranspose = + std::conditional_t; + // desc for problem definition constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; using EGridDesc_M_N = @@ -560,10 +649,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 using DsGridDesc_M_N = remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}, 1, 1))>; using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemmCTranspose::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}, 1, 1))>; using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; @@ -673,25 +762,22 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 p_ds_grid_{p_ds}, p_e_grid_{static_cast(p_e)}, a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, - a_g_n_c_wis_strides_{ - conv_ngchw_to_nhwgc_transformer - .TransposeInOutStrides( // TODO: Originally only used for transpose cases - a_g_n_c_wis_lengths, - a_g_n_c_wis_strides)}, + a_g_n_c_wis_strides_{NeedTransposeKernel + ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides) + : a_g_n_c_wis_strides}, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, - b_g_k_c_xs_strides_{ - conv_ngchw_to_nhwgc_transformer - .TransposeWeiStrides( // TODO: Originally only used for transpose cases - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides)}, + b_g_k_c_xs_strides_{NeedTransposeKernel + ? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides) + : b_g_k_c_xs_strides}, ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, - e_g_n_k_wos_strides_{ - conv_ngchw_to_nhwgc_transformer - .TransposeInOutStrides( // TODO: Originally only used for transpose cases - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, + e_g_n_k_wos_strides_{NeedTransposeKernel + ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides) + : e_g_n_k_wos_strides}, conv_filter_strides_{conv_filter_strides}, conv_filter_dilations_{conv_filter_dilations}, input_left_pads_{input_left_pads}, @@ -722,9 +808,14 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 cde_element_op_{cde_element_op} { // A/B/E Batch/N Stride - compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides_[0]; - compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0]; - compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_; + compute_ptr_offset_of_groups_.BatchStrideA_ = + CTranspose ? b_g_k_c_xs_strides_[0] : a_g_n_c_wis_strides_[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = + CTranspose ? a_g_n_c_wis_strides_[0] : b_g_k_c_xs_strides_[0]; + compute_ptr_offset_of_n_.BatchStrideA_ = + CTranspose ? 0 : a_g_n_c_wis_strides_[1] * conv_N_per_block_; + compute_ptr_offset_of_n_.BatchStrideB_ = + CTranspose ? a_g_n_c_wis_strides_[1] * conv_N_per_block_ : 0; // p_as and p_bs are pointers p_a_grid_ = static_cast(p_as); @@ -757,10 +848,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0]; compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_; - if constexpr(is_NGCHW_NGKHW() || // TODO: removed weight - // layout check - is_NGCDHW_NGKDHW()) // TODO: removed weight - // layout check + if constexpr(NeedTransposeKernel) { // Use not modified base strides a_in_transpose_desc_ = @@ -801,23 +889,24 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 const index_t GemmM = a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t GemmN = b_grid_desc_bk0_n_bk1_.GetLength(I1); - const auto MBlock = GridwiseGemm::CalculateMBlock(GemmM); - const auto NBlock = GridwiseGemm::CalculateNBlock(GemmN); + const auto MBlock = CTranspose ? GridwiseGemmCTranspose::CalculateMBlock(GemmN) + : GridwiseGemmCTranspose::CalculateMBlock(GemmM); + const auto NBlock = CTranspose ? GridwiseGemmCTranspose::CalculateNBlock(GemmM) + : GridwiseGemmCTranspose::CalculateNBlock(GemmN); ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n_, MBlock, NBlock); e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + GridwiseGemmCTranspose::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n_, MBlock, NBlock); } } std::size_t GetWorkspaceATensorSizeBytes() const { - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { const long_index_t a_acum = ck::accumulate_n( a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); @@ -830,14 +919,11 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 } } - // TODO: This might be dubious in the case there we need to transpose A but not B. Need to + // TODO: This might use unnecessary memory when we need to transpose A but not B. Need to // check how this is used. std::size_t GetWorkspaceBTensorSizeBytes() const { - if constexpr(is_NGCHW_NGKHW() || // TODO: removed weight - // layout check - is_NGCDHW_NGKDHW()) // TODO: removed weight - // layout check + if constexpr(NeedTransposeKernel) { const long_index_t b_acum = ck::accumulate_n( b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); @@ -852,8 +938,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 std::size_t GetWorkspaceETensorSizeBytes() const { - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { const long_index_t e_accum = ck::accumulate_n( e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); @@ -948,6 +1033,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 arg.Print(); } + printf("\033[035mCTranspose %d\033[0m\n", CTranspose); + float ave_time = 0; constexpr index_t minimum_occupancy = @@ -964,15 +1051,18 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 index_t gdx, gdy, gdz; // TODO: Do we want to support kbatch ?? std::tie(gdx, gdy, gdz) = - GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); + CTranspose + ? GridwiseGemmCTranspose::CalculateGridSize(GemmN, GemmM, I1 /*arg.KBatch*/) + : GridwiseGemmCTranspose::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); // TODO: Suspicious use of grid dims. Check run function. gdy = arg.num_group_; gdz = num_workgroups_per_Conv_N; // TODO: does this need to be updated for splitK? - index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock; - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock; + const bool has_main_k_block_loop = + GridwiseGemmCTranspose::CalculateHasMainKBlockLoop(K_split); // TODO: need arg.p_as_grid_? const ADataType* p_a_grid = arg.p_a_grid_; @@ -980,90 +1070,67 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 EDataType* p_e_grid = arg.p_e_grid_; // Transpose A and B, or just A. - if constexpr(is_NGCHW_GKCYX_NGKHW() || - is_NGCDHW_GKCZYX_NGKDHW()) + if constexpr(NeedTransposeKernel) { - p_a_grid = type_convert(arg.p_workspace_); - p_b_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); - p_e_grid = - type_convert(arg.p_workspace_) + - (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / - sizeof(EDataType); - } - else if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) - { - p_a_grid = type_convert(arg.p_workspace_); - p_e_grid = - type_convert(arg.p_workspace_) + - (arg.GetWorkspaceATensorSizeBytes() + - arg.GetWorkspaceBTensorSizeBytes()) / // TODO: This offset might be unnecessary - // if we are not doing a B transpose. - sizeof(EDataType); + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } + else if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + p_e_grid = type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + + arg.GetWorkspaceBTensorSizeBytes()) / // TODO: This offset might be + // unnecessary if we are not + // doing a B transpose. + sizeof(EDataType); + } } // TODO: Pretty much ok, but need p_as_grid and p_bs_grid static_assert(NumATensor == 1, "Num A Tensor should be 1\n"); static_assert(NumBTensor == 1, "Num B Tensor should be 1\n"); - typename GridwiseGemm::Argument gemm_arg{ - std::array{p_a_grid}, // p_as_grid - std::array{p_b_grid}, // p_bs_grid - arg.p_ds_grid_, - p_e_grid, - GemmM, - GemmN, - GemmK, - // No need to set strides, we pass descs to kernel - {I0}, // StrideAs - {I0}, // StrideBs - {}, // StrideDs - I0, // StrideE - I1, // kbatch - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_}; - // TODO: No is_reduce argument, defaults to false. - const auto Run = [&](const auto& kernel) { - // TODO: Rotating mem wrapper has an issue with the new gridwise arg. Not doing for - // now. - if(stream_config.flush_cache) + // TODO: To implement rotating mem wrapper for this device struct we need to use + // RotatingMemWrapperMultiABD and carefully consider what to do with the multiple A, + // B and D tensor sizes, as well as consider Ctranspose, (merge)groups, split_n + // and split_k. It might make more sense to do this after implementing all this + // functionality. + if(stream_config.flush_cache) {} + + if constexpr(CTranspose) { - // typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; - // ck::utility::RotatingMemWrapper - // rotating_mem( - // gemm_arg_, - // stream_config.rotating_count, - // gemm_arg_.M * gemm_arg_.K * sizeof(ADataType), - // gemm_arg_.K * gemm_arg_.N * sizeof(BDataType)); - // rotating_mem.Print(); + printf("Got Gemm MNK %d %d %d\n", GemmM, GemmN, GemmK); + typename GridwiseGemmCTranspose::Argument gemm_arg{ + std::array{p_b_grid}, // p_bs_grid + std::array{p_a_grid}, // p_as_grid + arg.p_ds_grid_, + p_e_grid, - // auto run_flush_cache = [&]() { - // // flush icache - // ck::utility::flush_icache(); - // // rotating mem - // rotating_mem.Next(); - // }; + GemmN, + GemmM, - // ave_time += ck::utility::launch_and_time_kernel_with_preprocess( - // stream_config, - // run_flush_cache, - // kernel, - // dim3(gdx, gdy, gdz), - // dim3(BlockSize), - // 0, - // gemm_arg_, - // arg.a_grid_desc_ak0_m_ak1_, - // arg.b_grid_desc_bk0_n_bk1_, - // arg.ds_grid_desc_m_n_, - // arg.e_grid_desc_m_n_, - // arg.compute_ptr_offset_of_groups_, - // arg.compute_ptr_offset_of_n_, - // KPerBlock); // TODO: splitK consideration (num_k_per_block) - - printf("\n\nAttempted to use rotating mem wrapper, not supported!\n\n"); + GemmK, + // No need to set strides, we pass descs to kernel + {I0}, // StrideAs + {I0}, // StrideBs + {}, // StrideDs + I0, // StrideE + I1, // kbatch + arg.b_element_op_, + arg.a_element_op_, + arg.cde_element_op_}; + // TODO: No is_reduce argument, defaults to false. ave_time += launch_and_time_kernel( stream_config, @@ -1072,8 +1139,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 dim3(BlockSize), 0, gemm_arg, - arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, + arg.a_grid_desc_ak0_m_ak1_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, @@ -1082,6 +1149,25 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 } else { + typename GridwiseGemm::Argument gemm_arg{ + std::array{p_a_grid}, // p_as_grid + std::array{p_b_grid}, // p_bs_grid + arg.p_ds_grid_, + p_e_grid, + GemmM, + GemmN, + GemmK, + // No need to set strides, we pass descs to kernel + {I0}, // StrideAs + {I0}, // StrideBs + {}, // StrideDs + I0, // StrideE + I1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; + // TODO: No is_reduce argument, defaults to false. + ave_time += launch_and_time_kernel( stream_config, kernel, @@ -1106,242 +1192,42 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { - const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffset, - true, // HasMainKBlockLoop - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - // TailNumber TailNum = TailNumber::Full - Run(kernel); + if constexpr(CTranspose) + { + const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3< + GridwiseGemmCTranspose, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffset, + true, // HasMainKBlockLoop + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + // TailNumber TailNum = TailNumber::Full + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffset, + true, // HasMainKBlockLoop + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + // TailNumber TailNum = TailNumber::Full + Run(kernel); + } } else { // TODO: check this in arg checker? printf("Unsupported pipeline version!\n"); } - // // Tail number could be One to Seven - // else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) - // { - // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) - // { - // const auto kernel = - // kernel_grouped_conv_fwd_xdl_cshuffle_v3; - // Run(kernel); - // } - // else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - // TailNumber::Full) - // { - // const auto kernel = - // kernel_grouped_conv_fwd_xdl_cshuffle_v3; - // Run(kernel); - // } - - // if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) - // { - // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) - // { - // const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - // GridwiseGemm, - // ComputePtrOffset, - // DeviceOp::AGridDesc_AK0_M_AK1, - // DeviceOp::BGridDesc_BK0_N_BK1, - // DeviceOp::DsGridDesc_M_N, - // DeviceOp::EGridDesc_M_N, - // true, - // InMemoryDataOperationEnum::Set, - // minimum_occupancy, - // TailNumber::Two>; - // Run(kernel); - // } - // } - - // if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) - // { - // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - // TailNumber::Three) - // { - // const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - // GridwiseGemm, - // ComputePtrOffset, - // DeviceOp::AGridDesc_AK0_M_AK1, - // DeviceOp::BGridDesc_BK0_N_BK1, - // DeviceOp::DsGridDesc_M_N, - // DeviceOp::EGridDesc_M_N, - // true, - // InMemoryDataOperationEnum::Set, - // minimum_occupancy, - // TailNumber::Three>; - // Run(kernel); - // } - // } - - // if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) - // { - // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) - // { - // const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - // GridwiseGemm, - // ComputePtrOffset, - // DeviceOp::AGridDesc_AK0_M_AK1, - // DeviceOp::BGridDesc_BK0_N_BK1, - // DeviceOp::DsGridDesc_M_N, - // DeviceOp::EGridDesc_M_N, - // true, - // InMemoryDataOperationEnum::Set, - // minimum_occupancy, - // TailNumber::Four>; - // Run(kernel); - // } - // } - - // if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) - // { - // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) - // { - // const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - // GridwiseGemm, - // ComputePtrOffset, - // DeviceOp::AGridDesc_AK0_M_AK1, - // DeviceOp::BGridDesc_BK0_N_BK1, - // DeviceOp::DsGridDesc_M_N, - // DeviceOp::EGridDesc_M_N, - // true, - // InMemoryDataOperationEnum::Set, - // minimum_occupancy, - // TailNumber::Five>; - // Run(kernel); - // } - // } - - // if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) - // { - // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) - // { - // const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - // GridwiseGemm, - // ComputePtrOffset, - // DeviceOp::AGridDesc_AK0_M_AK1, - // DeviceOp::BGridDesc_BK0_N_BK1, - // DeviceOp::DsGridDesc_M_N, - // DeviceOp::EGridDesc_M_N, - // true, - // InMemoryDataOperationEnum::Set, - // minimum_occupancy, - // TailNumber::Six>; - // Run(kernel); - // } - // } - - // if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) - // { - // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - // TailNumber::Seven) - // { - // const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - // GridwiseGemm, - // ComputePtrOffset, - // DeviceOp::AGridDesc_AK0_M_AK1, - // DeviceOp::BGridDesc_BK0_N_BK1, - // DeviceOp::DsGridDesc_M_N, - // DeviceOp::EGridDesc_M_N, - // true, - // InMemoryDataOperationEnum::Set, - // minimum_occupancy, - // TailNumber::Seven>; - // Run(kernel); - // } - // } - // } - // // Tail number could be Odd or Even - // else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - // { - // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - // { - // const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< - // GridwiseGemm, - // ComputePtrOffset, - // DeviceOp::AGridDesc_AK0_M_AK1, - // DeviceOp::BGridDesc_BK0_N_BK1, - // DeviceOp::DsGridDesc_M_N, - // DeviceOp::EGridDesc_M_N, - // true, - // InMemoryDataOperationEnum::Set, - // minimum_occupancy, - // TailNumber::Odd>; - // Run(kernel); - // } - // else - // { - // const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< - // GridwiseGemm, - // ComputePtrOffset, - // DeviceOp::AGridDesc_AK0_M_AK1, - // DeviceOp::BGridDesc_BK0_N_BK1, - // DeviceOp::DsGridDesc_M_N, - // DeviceOp::EGridDesc_M_N, - // true, - // InMemoryDataOperationEnum::Set, - // minimum_occupancy, - // TailNumber::Even>; - // Run(kernel); - // } - // } - // else - // { - // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - // { - // const auto kernel = - // kernel_grouped_conv_fwd_xdl_cshuffle_v3; - // Run(kernel); - // } - // else - // { - // const auto kernel = - // kernel_grouped_conv_fwd_xdl_cshuffle_v3; - // Run(kernel); - // } - // } } // has_main_k_block_loop else @@ -1350,18 +1236,36 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 // Tail number always 1 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffset, - false, // HasMainKBlockLoop - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - // TailNumber TailNum = TailNumber::Full - Run(kernel); + if constexpr(CTranspose) + { + const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3< + GridwiseGemmCTranspose, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffset, + false, // HasMainKBlockLoop + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + // TailNumber TailNum = TailNumber::Full + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffset, + false, // HasMainKBlockLoop + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + // TailNumber TailNum = TailNumber::Full + Run(kernel); + } } else { @@ -1381,8 +1285,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 { // At least transpose A from NGCHW to NHWGC, and if necessary transpose B from GKCYX // to GKYXC. - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { printf("\033[32mPerforming transpose forward\033[0m\n"); const index_t a_grid_size = @@ -1438,8 +1341,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 avg_time += RunGemm(arg, stream_config); // Transpose result back to NGCHW - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) + if constexpr(NeedTransposeKernel) { printf("\033[32mPerforming transpose back\033[0m\n"); const index_t grid_size = @@ -1501,6 +1403,12 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 const index_t G = arg.b_g_k_c_xs_lengths_[I0]; const index_t K = arg.b_g_k_c_xs_lengths_[I1]; const index_t C = arg.b_g_k_c_xs_lengths_[I2]; + + const index_t input_spatial_acum = ck::accumulate_n( + arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + const index_t output_spatial_acum = ck::accumulate_n( + arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + // Move this to runtime check to align Conv instances // with Conv Multiple D instances if constexpr(isMultiABD) @@ -1598,7 +1506,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v) + NeedTransposeKernel) { // TODO: This check originally said "ABlockTransferSrcVectorDim == 2", basically // blocking all instances with a value of 1. I've tried some though and they work just @@ -1616,6 +1524,23 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 return false; } } + else if constexpr(is_same_v || is_same_v) + { + static_assert(NeedTransposeKernel == false); + static_assert(NumGroupsToMerge == 1); + + if constexpr(ABlockTransferSrcScalarPerVector != 1) + { + if(ABlockTransferSrcVectorDim != 1) + { + return false; + } + if(input_spatial_acum % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + } else { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) @@ -1658,10 +1583,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 return false; } - if constexpr(is_NGCHW_NGKHW() || // TODO: Removed weight layout - // check. - is_NGCDHW_NGKDHW()) // TODO: Removed weight layout - // check. + if constexpr(NeedTransposeKernel) { if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0) { @@ -1687,11 +1609,6 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 return false; } - const index_t input_spatial_acum = ck::accumulate_n( - arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); - const index_t output_spatial_acum = ck::accumulate_n( - arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); - if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) @@ -1793,24 +1710,48 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); - typename GridwiseGemm::Argument gemm_arg{{nullptr}, - {nullptr}, - {}, - nullptr, - GemmM, - GemmN, - GemmK, - {I0}, - {I0}, - {}, - I0, - I1 /*KBatch*/, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_}; - // TODO: No is_reduce argument, defaults to false. + if constexpr(CTranspose) + { + typename GridwiseGemmCTranspose::Argument gemm_arg{{nullptr}, + {nullptr}, + {}, + nullptr, + GemmN, + GemmM, + GemmK, + {I0}, + {I0}, + {}, + I0, + I1 /*KBatch*/, + arg.b_element_op_, + arg.a_element_op_, + arg.cde_element_op_}; + // TODO: No is_reduce argument, defaults to false. - return GridwiseGemm::CheckValidity(gemm_arg); + return GridwiseGemmCTranspose::CheckValidity(gemm_arg); + } + else + { + typename GridwiseGemmCTranspose::Argument gemm_arg{{nullptr}, + {nullptr}, + {}, + nullptr, + GemmM, + GemmN, + GemmK, + {I0}, + {I0}, + {}, + I0, + I1 /*KBatch*/, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; + // TODO: No is_reduce argument, defaults to false. + + return GridwiseGemmCTranspose::CheckValidity(gemm_arg); + } } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index 3eb57ccda3..7532ccd7a1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -704,6 +704,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 const long_index_t a_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t b_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx)); const long_index_t e_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); @@ -717,7 +719,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 BsGridPointer p_bs_grid_; static_for<0, NumBTensor, 1>{}([&](auto i) { using BDataType_ = remove_cvref_t>; - p_bs_grid_(i) = static_cast(karg.p_bs_grid[i]) + b_batch_offset; + p_bs_grid_(i) = + static_cast(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset; }); DsGridPointer p_ds_grid_grp;