diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index f662ff834f..5a73d86ab6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -29,6 +29,11 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" + + +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + #ifdef CK_EXPERIMENTAL_BUILDER #include "ck_tile/builder/reflect/description.hpp" #include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" @@ -39,8 +44,8 @@ namespace tensor_operation { namespace device { template ::type; - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; - GridwiseGemm::template Run struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 @@ -164,6 +181,15 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 ComputeTypeA, ComputeTypeB> { + +#if defined USE_WAVE + + static_assert(UseThreadTileTransfer==false && + (ConvBackwardWeightSpecialization==ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0 + ),"Only Filter1x1Stride1Pad0is supported for wavetile transfer" + ); +#endif + using DeviceOp = DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3; using ADataType = OutDataType; @@ -275,6 +301,20 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 params, batch); } + + template + static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) + { + const auto grid_desc_m_k = transform_tensor_descriptor( + desc_k0_m_k1, + make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k; + } using ABCGridDescs = decltype(GetABCGridDesc()); @@ -282,6 +322,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; + using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_K0_M_K1{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_K0_N_K1{})); + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< tensor_layout::gemm::ColumnMajor, tensor_layout::gemm::RowMajor, @@ -334,7 +377,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 false, // permuteA false, // permuteB false, // IsBPreShuffled - true>; // ForceThreadTileTransfer + UseThreadTileTransfer>; // ForceThreadTileTransfer static constexpr auto MakeElementwiseInputSequence() { @@ -592,8 +635,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 p_b_grid_{p_in_grid}, p_ds_grid_{}, p_e_grid_{p_wei_grid}, - a_grid_desc_kbatch_k0_m_k1_{}, - b_grid_desc_kbatch_k0_n_k1_{}, + a_grid_desc_kbatch_m_k_{}, + b_grid_desc_kbatch_n_k_{}, ce_grid_desc_m_n_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_batch_{}, @@ -687,8 +730,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_k_c_xs_strides[i][0]; }); - a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; - b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + a_grid_desc_kbatch_m_k_ = transform_k0_m_k1_to_m_k(descs[I0]); + b_grid_desc_kbatch_n_k_ =transform_k0_m_k1_to_m_k(descs[I1]); ce_grid_desc_m_n_ = descs[I2]; ds_grid_descs_tuple_ = @@ -707,8 +750,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 index_t{1}, std::multiplies<>{}); - const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmM = a_grid_desc_kbatch_m_k_.GetLength(I0); + const index_t GemmN = b_grid_desc_kbatch_n_k_.GetLength(I0); c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -726,8 +769,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const BDataType* p_b_grid_; DsGridPointerTuple p_ds_grid_; EDataType* p_e_grid_; - AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + AGridDesc_M_K a_grid_desc_kbatch_m_k_; + BGridDesc_N_K b_grid_desc_kbatch_n_k_; CGridDesc_M_N ce_grid_desc_m_n_; CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; DsGridDesc_M_N ds_grid_descs_tuple_; @@ -784,10 +827,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { float ave_time = 0; - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); + AccDataType* p_e_grid = type_convert(arg.p_workspace_); @@ -817,8 +860,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - const auto num_k_per_block = - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded( arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}),arg.k_batch_)); const auto clear_workspace = [&]() { hip_check_error(hipMemsetAsync( @@ -831,11 +873,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; std::array size_as_buffers; - size_as_buffers[0] = arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * + size_as_buffers[0] = arg.a_grid_desc_kbatch_m_k_.GetElementSpaceSize() * sizeof(ADataType) / GridwiseGemm::APackedSize; std::array size_bs_buffers; - size_bs_buffers[0] = arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * + size_bs_buffers[0] = arg.b_grid_desc_kbatch_n_k_.GetElementSpaceSize() * sizeof(BDataType) / GridwiseGemm::BPackedSize; std::array size_ds_buffers; @@ -865,8 +907,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_kbatch_m_k_, + arg.b_grid_desc_kbatch_n_k_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -881,8 +923,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 dim3(BlockSize), 0, gemm_arg, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_kbatch_m_k_, + arg.b_grid_desc_kbatch_n_k_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block); @@ -903,8 +945,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -918,8 +960,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -944,8 +986,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -959,8 +1001,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< GridwiseGemm, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, @@ -1030,10 +1072,17 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif + + const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1); + typename GridwiseGemm::Argument gemm_arg{std::array{nullptr}, // p_as_grid std::array{nullptr}, // p_bs_grid diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index eb516378b4..ac23bcd325 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -181,11 +181,18 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 static_assert(is_same_v); using DeviceOp = DeviceGroupedConvBwdWeight_Wmma_CShuffleV3; - + using ADataType = OutDataType; using BDataType = InDataType; using CDataType = WeiDataType; +// // static const auto F1S1 = ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; +// #if defined USE_WAVE +// static_assert(UseThreadTileTransfer==false && +// (ConvBackwardWeightSpecialization==ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0 +// ),"Only Filter1x1Stride1Pad0is supported for wavetile transfer" +// ); +// #endif // If NGCHW then ADataType must be equal to BDataType static_assert(!(is_NGCHW_NGKHW() || is_NGCDHW_NGKDHW()) || diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp index fa613bd836..4c0f9e5276 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_wave_transfer_instances.hpp @@ -43,11 +43,11 @@ template using device_grouped_conv_bwd_data_wmma_cshufflev3_bf16_wave_transfer_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // generic instance + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, @@ -63,12 +63,11 @@ template using device_grouped_conv_bwd_data_wmma_cshufflev3_f16_wave_transfer_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - - // generic instance + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp new file mode 100644 index 0000000000..cd43e558a0 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp @@ -0,0 +1,93 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#define USE_WAVE +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using namespace ck::tensor_layout::convolution; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif + +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif + + + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances = std::tuple< + // clang-format off + //#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| + //#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline | + //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, F16, F16, F16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, F16, F16, F16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false> + + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances = std::tuple< + // clang-format off + //#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| + //#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline | + //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, BF16, BF16, BF16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>, + + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, BF16, BF16, BF16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false> + + //clang-format on + >; + + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index c07dc71ac5..7ecb626212 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -12,6 +12,11 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#if defined USE_WAVE +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp" + +#endif + #ifdef DL_KERNELS #include "grouped_convolution_backward_weight_dl.inc" #endif @@ -957,7 +962,73 @@ struct DeviceOperationInstanceFactory +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD< + NumDimSpatial, + InLayout, + WeiLayout, + OutLayout, + DsLayout, + InDataType, + WeiDataType, + OutDataType, + DsDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ComputeTypeA, + ComputeTypeB>> +{ + using DeviceOp = + DeviceGroupedConvBwdWeightMultipleD; + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_BF16 + add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances( + op_ptrs); +#endif +#ifdef CK_ENABLE_FP16 + add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances( + op_ptrs); +#endif + return op_ptrs; + + } + + }; +#endif } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc index 06247019f1..09977754a3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc @@ -114,6 +114,72 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf PassThrough>>>& instances); #endif +#if defined USE_WAVE +#ifdef CK_ENABLE_BF16 + +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances( + std::vector, + BF16, + BF16, + BF16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances( + std::vector, + BF16, + BF16, + BF16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +#endif + +#ifdef CK_ENABLE_FP16 + +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances( + std::vector, + F16, + F16, + F16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances( + std::vector, + F16, + F16, + F16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +#endif +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index 268835d5bf..59a30d4ea9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -79,6 +79,8 @@ list(APPEND GROUPED_CONV2D_BWD_WEIGHT wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp ) add_instance_library(device_grouped_conv2d_bwd_weight_instance ${GROUPED_CONV2D_BWD_WEIGHT}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp new file mode 100644 index 0000000000..9b99f910bd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances( + std::vector, + BF16, + BF16, + BF16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp new file mode 100644 index 0000000000..4f84fce872 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances( + std::vector, + F16, + F16, + F16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index b246b87178..e7d8403812 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -73,6 +73,8 @@ list(APPEND GROUPED_CONV3D_BWD_WEIGHT wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp ) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp new file mode 100644 index 0000000000..0fb351d88d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances( + std::vector, + BF16, + BF16, + BF16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp new file mode 100644 index 0000000000..6f4db11b4d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances( + std::vector, + F16, + F16, + F16, + Tuple<>, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck