From d9298b74e342ca146a041ca221b33a93ae17d5bf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?=
Date: Sat, 28 Feb 2026 02:24:30 +0100
Subject: [PATCH] [CK] Port non-grouped convolution instances to the grouped kernels (#4875)

## Motivation

Port the non-grouped convolution instances to the grouped kernels so that the
older non-grouped implementations can be deprecated.

## Technical Details

Add the same instances as the non-grouped ones, but built on the grouped
kernel. (A short usage sketch showing how the new instances are reached
through the instance factory follows the diff.)

## Test Plan

test_grouped_convnd_fwd

## Test Result

pass

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

AICK-724
---
 .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 14 +++
 .../gpu/grouped_convolution_forward.hpp       |  4 +
 .../gpu/grouped_convolution_forward_xdl.inc   | 28 ++++++
 .../gpu/grouped_conv2d_fwd/CMakeLists.txt     |  2 +
 ...xc_nhwgk_nongroup_ported_bf16_instance.cpp | 86 +++++++++++++++++++
 ...yxc_nhwgk_nongroup_ported_f16_instance.cpp | 85 ++++++++++++++++++
 6 files changed, 219 insertions(+)
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_nongroup_ported_bf16_instance.cpp
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_nongroup_ported_f16_instance.cpp

diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
index 8efa0e355d..d66679a318 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
@@ -376,6 +376,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
             return false;
         }

+        if constexpr(KPerBlock < 16)
+        {
+            return false;
+        }
 #endif

         if constexpr(Base::GetSharedMemoryNumberOfByte(get_device_arch()) >
@@ -415,6 +419,16 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
         static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
                       "KPerBlock must be divisible by AK1Value and BK1Value!");

+#ifndef __HIPCC_RTC__
+        if constexpr(KPerBlock < 16)
+        {
+            if(ck::is_gfx12_supported() || ck::is_gfx11_supported())
+            {
+                return false;
+            }
+        }
+#endif
+
         const auto M  = a_grid_desc_m_k.GetLength(I0);
         const auto N  = b_grid_desc_n_k.GetLength(I0);
         const auto AK = a_grid_desc_m_k.GetLength(I1);
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
index 08e2092c50..b90cd44df0 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
@@ -251,6 +251,8 @@ struct DeviceOperationInstanceFactory
                      && is_same_v && is_same_v)
         {
+            add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_nongroup_ported_instances(
+                op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances(op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
@@ -276,6 +278,8 @@
                      && is_same_v)
         {
+            add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_nongroup_ported_instances(
+                op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index 59225787a7..d21393e8b1 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -124,6 +124,20 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( // grouped conv2d forward, NHWGC/GKYXC/NHWGK #ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_nongroup_ported_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( std::vector>>& instances); + void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( std::vector + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using NHWGC = ck::tensor_layout::convolution::NHWGC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using NHWGK = ck::tensor_layout::convolution::NHWGK; + +using EmptyTuple = Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_nongroup_ported_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + 
//##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, 
PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 128, 64, 8, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 256, 64, 8, 4, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 128, 128, 64, 8, 4, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, BF16, BF16, F32, BF16, EmptyTuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 128, 64, 64, 8, 4, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_nongroup_ported_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_nongroup_ported_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_nongroup_ported_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_nongroup_ported_f16_instance.cpp new file mode 100644 index 0000000000..c4be6ea6d0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_nongroup_ported_f16_instance.cpp @@ -0,0 +1,85 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using NHWGC = ck::tensor_layout::convolution::NHWGC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using NHWGK = ck::tensor_layout::convolution::NHWGK; + +using EmptyTuple = Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_nongroup_ported_instances = std::tuple< + // clang-format off + //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 
0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 
2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 128, 64, 8, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 256, 256, 64, 8, 4, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 128, 128, 64, 8, 4, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< 2 ,NHWGC, GKYXC, EmptyTuple, NHWGK, F16, F16, F32, F16, EmptyTuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault,GemmMNKPadding, 1, 128, 64, 64, 8, 4, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_nongroup_ported_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_nongroup_ported_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck
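---

For reviewers: the new `*_nongroup_ported_*` instances are only reachable through the factory path patched in `grouped_convolution_forward.hpp`, so a quick way to confirm they are registered is to enumerate the instances for the 2-D NHWGC/GKYXC/NHWGK case. The sketch below is not part of this patch; it is a minimal example assuming the usual CK client-API pattern (`DeviceOperationInstanceFactory<...>::GetInstances()`) and an illustrative parameter order for `DeviceGroupedConvFwdMultipleABD`, which should be checked against the factory specialization in `grouped_convolution_forward.hpp` before use.

```cpp
// Usage sketch (assumption: interface template-argument order as shown here;
// verify against the DeviceOperationInstanceFactory specialization in
// grouped_convolution_forward.hpp).
#include <iostream>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"

int main()
{
    using NHWGC       = ck::tensor_layout::convolution::NHWGC;
    using GKYXC       = ck::tensor_layout::convolution::GKYXC;
    using NHWGK       = ck::tensor_layout::convolution::NHWGK;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using F16         = ck::half_t;

    // Interface type the patched factory registers instances for:
    // 2-D, NHWGC/GKYXC/NHWGK, FP16 in/wei/out, no extra D tensors.
    using DeviceOp =
        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<2,
                                                                      NHWGC,
                                                                      GKYXC,
                                                                      ck::Tuple<>,
                                                                      NHWGK,
                                                                      F16,
                                                                      F16,
                                                                      ck::Tuple<>,
                                                                      F16,
                                                                      PassThrough,
                                                                      PassThrough,
                                                                      PassThrough>;

    // After this PR the returned list should also contain the
    // *_nongroup_ported_* instances, alongside the pre-existing grouped ones.
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(const auto& op : op_ptrs)
    {
        std::cout << op->GetTypeString() << '\n';
    }
    std::cout << op_ptrs.size() << " instances registered\n";
    return 0;
}
```

test_grouped_convnd_fwd exercises the same factory path, so this is only a convenience for spot-checking instance names; linking follows the usual device-instance library setup built from the CMakeLists touched in this patch.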