From a8aebb7a8efbd9860487a4bc563706cf7a71f988 Mon Sep 17 00:00:00 2001 From: Wojciech Laskowski <77888887+wj-laskowski@users.noreply.github.com> Date: Mon, 22 Dec 2025 15:57:45 +0100 Subject: [PATCH] Post-merge cleanup for WMMA grouped conv fwd (#3468) * remove duplicate aliases * Split scaleadd_ab instances for WMMA grouped conv fwd * removed big shape from the test --- ...d_wmma_cshufflev3_scaleadd_ab_instance.hpp | 140 ++++++++++++++++-- ...rouped_convolution_forward_scaleadd_ab.hpp | 98 +++++++++++- .../CMakeLists.txt | 10 +- ...wgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp} | 32 ++-- ...hwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp | 52 +++++++ ...hwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp | 52 +++++++ ...hwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp | 52 +++++++ ...hwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp} | 32 ++-- ...dhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp | 52 +++++++ ...dhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp | 52 +++++++ ...dhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp | 52 +++++++ .../src/profile_grouped_conv_fwd_bilinear.cpp | 4 - .../test_grouped_convnd_fwd_scaleadd_ab.cpp | 4 - 13 files changed, 570 insertions(+), 62 deletions(-) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/{device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp => device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp} (85%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/{device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp => device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp} (86%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp index d97cd6f04c..00c3d44dde 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp @@ -40,7 +40,7 @@ template -using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances = +using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part1 = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | @@ -57,30 +57,84 @@ using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances = DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part2 = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part3 = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part4 = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> #endif @@ -92,7 +146,7 @@ template -using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances = +using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part1 = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | @@ -109,30 +163,84 @@ using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances = DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part2 = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part3 = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part4 = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> #endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp index c651aab2c9..d35f159eb5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp @@ -91,7 +91,46 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_ins #ifdef CK_USE_WMMA #ifdef CK_ENABLE_BF16 // grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK -void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4( std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + F16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + F16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + F16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4( std::vector> && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances( + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4( op_ptrs); } #endif @@ -227,7 +311,13 @@ struct DeviceOperationInstanceFactory> && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4( op_ptrs); } #endif diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt index aa3dd0af12..b058e0fa9d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt @@ -9,8 +9,14 @@ set(GROUPED_CONV3D_FWD_SCALEADD_AB xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp # WMMA CSHUFFLE V3 - wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp + wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp + wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp + wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp + wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp + wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp + wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp + wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp ) add_instance_library(device_grouped_conv3d_fwd_scaleadd_ab_instance ${GROUPED_CONV3D_FWD_SCALEADD_AB}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp similarity index 85% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp index f1a711de5e..8388bb9c21 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1( std::vector{}); + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part1<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwdDefault>{}); add_device_operation_instances( instances, - device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvFwd1x1P0>{}); + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part1<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1P0>{}); add_device_operation_instances( instances, - device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvFwd1x1S1P0>{}); + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part1<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp new file mode 100644 index 0000000000..f0950e8602 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part2<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part2<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part2<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp new file mode 100644 index 0000000000..4b0792527c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part3<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part3<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part3<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp new file mode 100644 index 0000000000..586eba20b5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part4<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part4<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part4<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp similarity index 86% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp index 5aa527d829..25dcbab0fa 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances( +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1( std::vector{}); + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part1<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwdDefault>{}); add_device_operation_instances( instances, - device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvFwd1x1P0>{}); + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part1<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1P0>{}); add_device_operation_instances( instances, - device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvFwd1x1S1P0>{}); + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part1<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp new file mode 100644 index 0000000000..75ba36eed2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + F16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part2<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part2<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part2<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp new file mode 100644 index 0000000000..a9fe162e72 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + F16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part3<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part3<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part3<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp new file mode 100644 index 0000000000..364a1a90e0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + F16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part4<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part4<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part4<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/src/profile_grouped_conv_fwd_bilinear.cpp b/profiler/src/profile_grouped_conv_fwd_bilinear.cpp index d4490abe7e..906b625479 100644 --- a/profiler/src/profile_grouped_conv_fwd_bilinear.cpp +++ b/profiler/src/profile_grouped_conv_fwd_bilinear.cpp @@ -108,10 +108,6 @@ int grouped_conv_fwd_bilinear(int argc, char* argv[]) using NDHWGC = ck::tensor_layout::convolution::NDHWGC; using NDHWGK = ck::tensor_layout::convolution::NDHWGK; - using GKZYXC = ck::tensor_layout::convolution::GKZYXC; - using NDHWGC = ck::tensor_layout::convolution::NDHWGC; - using NDHWGK = ck::tensor_layout::convolution::NDHWGK; - constexpr auto I3 = ck::Number<3>{}; auto profile = [&](auto num_dim_spatial_tmp, diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp index ab7a28a388..199a50f0fd 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp @@ -368,10 +368,6 @@ TYPED_TEST(TestGroupedConvndFwdScaleaddAB3d, Test3D) { this->conv_params.clear(); - // Client example 24. This one takes quite long. - this->conv_params.push_back( - {3, 32, 64, 32, 64, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); - // Generic problems, same set as for vanilla, clamp, and (gk) bias clamp tests. this->conv_params.push_back( {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});