From 0b7f0cbbeb565bcf007cb62e7b4be4e46b4084cd Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 1 Aug 2025 14:12:22 +0000 Subject: [PATCH] Add instances for xdl parity (for pipeline v1) --- ...ouped_conv_bwd_weight_v3_wmma_instance.hpp | 36 ++++-- ..._grouped_conv_bwd_weight_wmma_instance.hpp | 51 +++++++-- .../grouped_convolution_backward_weight.hpp | 14 +++ ...ouped_convolution_backward_weight_wmma.inc | 108 ++++++++++++++++-- .../grouped_conv2d_bwd_weight/CMakeLists.txt | 3 + ...hwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp | 10 ++ ...ht_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 48 ++++++++ ..._gkyxc_nhwgk_bf16_pad0_pipev1_instance.cpp | 40 +++++++ ...c_gkyxc_nhwgk_f16_pad0_pipev1_instance.cpp | 39 +++++++ .../grouped_conv3d_bwd_weight/CMakeLists.txt | 6 +- ...c_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp | 10 ++ ...xc_ndhwgk_bf16_default_pipev1_instance.cpp | 39 +++++++ ...kzyxc_ndhwgk_bf16_pad0_pipev1_instance.cpp | 39 +++++++ ...yxc_ndhwgk_f16_default_pipev1_instance.cpp | 39 +++++++ ...gkzyxc_ndhwgk_f16_pad0_pipev1_instance.cpp | 39 +++++++ 15 files changed, 489 insertions(+), 32 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_pad0_pipev1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev1_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp index 2f3987051c..ad46827cb5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp @@ -49,12 +49,18 @@ template using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances = std::tuple< // clang-format off - //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| - //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| - //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | - // generic instance - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion> + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 16, 16, 2, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 2, 5, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 2, 7, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion> // clang-format on >; @@ -67,12 +73,18 @@ template using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances = std::tuple< // clang-format off - //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| - //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| - //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | - // generic instance - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion> + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 16, 16, 2, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 2, 5, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 2, 7, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion> //clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp index db9e4721bd..aa63cdd7ba 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp @@ -47,7 +47,7 @@ using device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_generic_instances = std: //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | // generic instance - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector> + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector> // clang-format on >; @@ -64,11 +64,11 @@ using device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_instances = std::tuple< //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | - // generic instance + // // generic instance DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, - // instance for small conv.K - // for fp16 conv.K and conv.C must be divisible by 2 - // since half_t atomic_add require scalar_per_x_vector % 2 == 0 + // // instance for small conv.K + // // for fp16 conv.K and conv.C must be divisible by 2 + // // since half_t atomic_add require scalar_per_x_vector % 2 == 0 DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, @@ -117,7 +117,24 @@ using device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_f32_bf16_instances = st //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | // generic instance - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector> + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + // instance for small conv.K + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 32, 8, 16, 16, 2, 4, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector> // clang-format on >; @@ -135,7 +152,27 @@ using device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_instances = std::tuple< //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | // generic instance - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector> + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + // instance for small conv.K + // for bf16 conv.K and conv.C must be divisible by 2 + // since half_t atomic_add require scalar_per_x_vector % 2 == 0 + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 32, 8, 16, 16, 2, 4, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector> + // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index 8f4ea7d6bc..2d3ac011d4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -848,6 +848,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { + add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + op_ptrs); add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instances( op_ptrs); add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instances( @@ -877,6 +879,8 @@ struct DeviceOperationInstanceFactory>>& instances); + void add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_default_pipev1_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_pad0_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instances( std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 @@ -209,6 +233,18 @@ void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_default_pi PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev1_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_f16_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev1_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( std::vector{}); + + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_f32_bf16_generic_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp new file mode 100644 index 0000000000..b8aed9e562 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_generic_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_generic_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev1_instance.cpp new file mode 100644 index 0000000000..0fca8d8e65 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev1_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev1_instances( + std::vector>>& instances) +{ + + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_pad0_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_pad0_pipev1_instance.cpp new file mode 100644 index 0000000000..10dec0f750 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_pad0_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_pad0_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 40ac353587..6f33fa1dda 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -64,9 +64,13 @@ list(APPEND GROUPED_CONV3D_BWD_WEIGHT wmma/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp wmma/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp - wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev1_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev1_instance wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev1_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev1_instance wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp index f2cd2a2311..933bcc8720 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp @@ -30,6 +30,16 @@ void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_bf16_f32_bf1 GKZYXC, GNDHWK, ConvBwdWeightDefault>{}); + + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_f32_bf16_generic_instances< + 3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev1_instance.cpp new file mode 100644 index 0000000000..13f6963be1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev1_instance.cpp new file mode 100644 index 0000000000..4fb1e0a888 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev1_instance.cpp new file mode 100644 index 0000000000..59769b8843 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev1_instance.cpp new file mode 100644 index 0000000000..1b6a2de45f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck