From 92e8a1035ea87443697acd2a36a85f4083e4c763 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 3 Feb 2026 15:00:11 +0000 Subject: [PATCH] Built fix remove multi D functionality --- ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 132 +++++++----------- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 23 ++- ..._weight_v3_wmma_wave_transfer_instance.hpp | 35 +++-- .../grouped_convolution_backward_weight.hpp | 83 ++--------- ...ouped_convolution_backward_weight_wmma.inc | 115 +++++++-------- ...kyxc_nhwgk_bf16_wave_transfer_instance.cpp | 4 +- ...gkyxc_nhwgk_f16_wave_transfer_instance.cpp | 10 +- ...yxc_ndhwgk_bf16_wave_transfer_instance.cpp | 4 +- ...zyxc_ndhwgk_f16_wave_transfer_instance.cpp | 4 +- 9 files changed, 136 insertions(+), 274 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index 88cdef7548..f662ff834f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -29,9 +29,6 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" -#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" -#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" - #ifdef CK_EXPERIMENTAL_BUILDER #include "ck_tile/builder/reflect/description.hpp" #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" @@ -42,8 +39,8 @@ namespace tensor_operation { namespace device { template ::type; - constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); __shared__ char p_shared[LDS_size]; - auto epilogue_args = EpilogueType{}; + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; - const auto a_grid_desc_ak0_m_ak1 = - GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); - - const auto b_grid_desc_bk0_n_bk1 = - GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); - - GridwiseGemm::template Run struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 @@ -178,15 +164,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 ComputeTypeA, ComputeTypeB> { - -#if defined USE_WAVE_TRANSFER_BWD_WEI - - static_assert(UseThreadTileTransfer == false && - (ConvBackwardWeightSpecialization == - ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0), - "Only Filter1x1Stride1Pad0is supported for wavetile transfer"); -#endif - using DeviceOp = DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3; using ADataType = OutDataType; @@ -299,29 +276,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 batch); } - template - static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) - { - const auto grid_desc_m_k = transform_tensor_descriptor( - desc_k0_m_k1, - make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), - make_merge_transform( - make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return grid_desc_m_k; - } - using ABCGridDescs = decltype(GetABCGridDesc()); using AGridDesc_K0_M_K1 = remove_cvref_t; using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_K0_M_K1{})); - using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_K0_N_K1{})); - using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< tensor_layout::gemm::ColumnMajor, tensor_layout::gemm::RowMajor, @@ -371,10 +331,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, - false, // permuteA - false, // permuteB - false, // IsBPreShuffled - UseThreadTileTransfer>; // ForceThreadTileTransfer + false, // permuteA + false, // permuteB + false, // IsBPreShuffled + true>; // ForceThreadTileTransfer static constexpr auto MakeElementwiseInputSequence() { @@ -632,8 +592,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 p_b_grid_{p_in_grid}, p_ds_grid_{}, p_e_grid_{p_wei_grid}, - a_grid_desc_kbatch_m_k_{}, - b_grid_desc_kbatch_n_k_{}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, ce_grid_desc_m_n_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_batch_{}, @@ -727,9 +687,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_k_c_xs_strides[i][0]; }); - a_grid_desc_kbatch_m_k_ = transform_k0_m_k1_to_m_k(descs[I0]); - b_grid_desc_kbatch_n_k_ = transform_k0_m_k1_to_m_k(descs[I1]); - ce_grid_desc_m_n_ = descs[I2]; + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + ce_grid_desc_m_n_ = descs[I2]; ds_grid_descs_tuple_ = MakeDsGridDescriptor_M_N(ds_g_k_c_xs_lengths, ds_g_k_c_xs_strides); @@ -747,8 +707,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 index_t{1}, std::multiplies<>{}); - const index_t GemmM = a_grid_desc_kbatch_m_k_.GetLength(I0); - const index_t GemmN = b_grid_desc_kbatch_n_k_.GetLength(I0); + const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -766,8 +726,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const BDataType* p_b_grid_; DsGridPointerTuple p_ds_grid_; EDataType* p_e_grid_; - AGridDesc_M_K a_grid_desc_kbatch_m_k_; - BGridDesc_N_K b_grid_desc_kbatch_n_k_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; CGridDesc_M_N ce_grid_desc_m_n_; CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; DsGridDesc_M_N ds_grid_descs_tuple_; @@ -824,9 +784,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { float ave_time = 0; - const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); - const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); - const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); AccDataType* p_e_grid = type_convert(arg.p_workspace_); @@ -856,8 +817,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded( - arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}), arg.k_batch_)); + const auto num_k_per_block = + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; const auto clear_workspace = [&]() { hip_check_error(hipMemsetAsync( @@ -870,11 +831,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; std::array size_as_buffers; - size_as_buffers[0] = arg.a_grid_desc_kbatch_m_k_.GetElementSpaceSize() * + size_as_buffers[0] = arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) / GridwiseGemm::APackedSize; std::array size_bs_buffers; - size_bs_buffers[0] = arg.b_grid_desc_kbatch_n_k_.GetElementSpaceSize() * + size_bs_buffers[0] = arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) / GridwiseGemm::BPackedSize; std::array size_ds_buffers; @@ -904,8 +865,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg_, - arg.a_grid_desc_kbatch_m_k_, - arg.b_grid_desc_kbatch_n_k_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -920,8 +881,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg, - arg.a_grid_desc_kbatch_m_k_, - arg.b_grid_desc_kbatch_n_k_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -942,8 +903,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -957,8 +918,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -983,8 +944,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -998,8 +959,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -1069,9 +1030,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { - const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); - const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); - const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); typename GridwiseGemm::Argument gemm_arg{std::array{nullptr}, // p_as_grid std::array{nullptr}, // p_bs_grid diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index e60f95d9cb..181f67fabf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -184,16 +184,15 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 using ADataType = OutDataType; using BDataType = InDataType; using CDataType = WeiDataType; - // // static const auto F1S1 = - // ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; - // #if defined USE_WAVE - // static_assert(UseThreadTileTransfer==false && - // (ConvBackwardWeightSpecialization==ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0 - // ),"Only Filter1x1Stride1Pad0is supported for wavetile transfer" - // ); - // #endif - // If NGCHW then ADataType must be equal to BDataType + #if defined USE_WAVE_TRANSFER_BWD_WEI + + static_assert(UseThreadTileTransfer==false && + (ConvBackwardWeightSpecialization==ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0 + ),"Only Filter1x1Stride1Pad0is supported for wavetile transfer" + ); + #endif + // If NGCHW then ADataType must be equal to BDataType static_assert(!(is_NGCHW_NGKHW() || is_NGCDHW_NGKDHW()) || is_same_v); @@ -1094,12 +1093,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp index 4984ac1cec..f74bbbef87 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp @@ -3,11 +3,12 @@ #define USE_WAVE_TRANSFER_BWD_WEI #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + namespace ck { namespace tensor_operation { namespace device { @@ -27,6 +28,8 @@ using F8 = ck::f8_t; using BF8 = ck::bf8_t; #endif + + using Empty_Tuple = ck::Tuple<>; template @@ -49,15 +52,13 @@ template using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances = std::tuple< // clang-format off - //#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| - //#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline | - //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | - //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | - // generic instance - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, F16, F16, F16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, - - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, F16, F16, F16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false> - + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion, false>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion, false> + // clang-format on >; @@ -70,14 +71,12 @@ template using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances = std::tuple< // clang-format off - //#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| - //#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline | - //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | - //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | - // generic instance - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, BF16, BF16, BF16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, - - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, BF16, BF16, BF16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false> + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion, false>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion, false> //clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index 56f511bc89..9ac16834f0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -12,11 +12,6 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#if defined USE_WAVE_TRANSFER_BWD_WEI -#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp" - -#endif - #ifdef DL_KERNELS #include "grouped_convolution_backward_weight_dl.inc" #endif @@ -398,9 +393,6 @@ struct DeviceOperationInstanceFactory -struct DeviceOperationInstanceFactory< - ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD< - NumDimSpatial, - InLayout, - WeiLayout, - OutLayout, - DsLayout, - InDataType, - WeiDataType, - OutDataType, - DsDataType, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ComputeTypeA, - ComputeTypeB>> -{ - using DeviceOp = - DeviceGroupedConvBwdWeightMultipleD; - static auto GetInstances() - { - std::vector> op_ptrs; - -#ifdef CK_ENABLE_BF16 - add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances( - op_ptrs); -#endif -#ifdef CK_ENABLE_FP16 - add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances( - op_ptrs); -#endif - return op_ptrs; - } -}; -#endif } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc index 33f4165c28..ba44bf5b69 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc @@ -22,6 +22,18 @@ void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances( + std::vector>>& instances); + + void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances( std::vector>>& instances); #endif -#if defined USE_WAVE_TRANSFER_BWD_WEI -#ifdef CK_ENABLE_BF16 - -void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances( - std::vector, - BF16, - BF16, - BF16, - Tuple<>, - PassThrough, - PassThrough, - PassThrough>>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances( - std::vector, - BF16, - BF16, - BF16, - Tuple<>, - PassThrough, - PassThrough, - PassThrough>>>& instances); - -#endif - -#ifdef CK_ENABLE_FP16 - -void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances( - std::vector, - F16, - F16, - F16, - Tuple<>, - PassThrough, - PassThrough, - PassThrough>>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances( - std::vector, - F16, - F16, - F16, - Tuple<>, - PassThrough, - PassThrough, - PassThrough>>>& instances); - -#endif -#endif - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp index 36312b814e..126d744437 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp @@ -11,15 +11,13 @@ namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances( - std::vector, BF16, BF16, BF16, - Tuple<>, PassThrough, PassThrough, PassThrough>>>& instances) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp index 60287b459c..0c41d91105 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp @@ -11,15 +11,13 @@ namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances( - std::vector, - F16, - F16, - F16, - Tuple<>, + F16, + F16, + F16, PassThrough, PassThrough, PassThrough>>>& instances) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp index eee14e3236..5d6be45bfb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp @@ -11,15 +11,13 @@ namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances( - std::vector, BF16, BF16, BF16, - Tuple<>, PassThrough, PassThrough, PassThrough>>>& instances) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp index 0eeef111a2..200d150beb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp @@ -11,15 +11,13 @@ namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances( - std::vector, F16, F16, F16, - Tuple<>, PassThrough, PassThrough, PassThrough>>>& instances)