From e6d1781f20417d20f0c17cc6440dd042b31aef33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <188998872+vpietila-amd@users.noreply.github.com> Date: Fri, 17 Apr 2026 06:17:30 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#6421 (commit 05b0753) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [MIOpen][CK] Fix bwd weight conv test failures by disabling one block-GEMM V5 instance for 3D convs (#6421) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Due to compiler version update, there are test failures in the test target `test_grouped_convnd_bwd_weight` when running on `gfx90a`. There are four failing tests for FP16/BF16 that arise from a single kernel instance. As the problem is in the current develop branch, the test failures are blocking any PR merges into develop. An example of a failed CI runs is here: [http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/develop/558/pipeline/](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/develop/558/pipeline/). The underlying compiler problem is potentially the same as described in #6342 as the tests are passing for clang compiler version 20.0 and failing for clang compiler version 22.0. First attempt to fix this problem had to be reverted in #6400 because it broke MIOpen internal DB sync tests. ## Technical Details The root cause for the test failures are the block-GEMM V5 instances of `DeviceGroupedConvBwdWeight_Xdl_CShuffleV3` that have large tile size. The V5 pipeline uses double register buffer that in combination with large tile size causes high register pressure. The latest version of compiler handles the register spillage incorrectly for `gfx90a`, which cause the kernel to output incorrect results. The BF16/FP16 instances of `DeviceGroupedConvBwdWeight_Xdl_CShuffleV3` that do not use direct load for are divided into two groups - Base instances - Instances that result into high register usage (currently only one instance - one that causes the test failures). This division allows to disable only the V5 block-GEMM flavor of `DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<64, 128, 32, 32, Default, 8, 4, 1, 8, 8, 8, 8, 1, 1, 2>` for 3D convolutions on `gfx90a`. The selective disabling leaves the set of instances for 1D and 2D convolutions unaffected, and removes at runtime two V5 block-GEMM instances (`ConvBwdWeightDefault` and `ConvBwdWeightFilter1x1Stride1Pad0`) per data type (FP16/BF16) when the device is `gfx90a`. Because MIOpen uses CK's type string (provided by method `GetTypeString`) to identify the instances, the DB sync tests are expected to unaffected since there are still the V2 block-GEMM instances that result in the same type string (`DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<64, 128, 32, 32, Default, 8, 4, 1, 8, 8, 8, 8, 1, 1, 2>`). This expectation needs to be verified by running the MIOpen DB sync tests that are not part of the normal CK PR build. ## Test Plan Running all CI tests + the MIOpen internal DB sync tests is sufficient to verify the correctness of the code changes. ## Test Result Verified locally that the previously failing tests `TestGroupedConvndBwdWeight3d/4.Test3D` and `TestGroupedConvndBwdWeight3d/4.Test3D` have instance counts - 231 on `gfx90a` - 233 on `gfx942` and are currently passing. This confirms the expectation that two instances per data type should be disabled on `gfx90a`. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. Co-authored-by: Ville Pietilä <> --- include/ck/host_utility/device_prop.hpp | 2 + ...rouped_conv_bwd_weight_v3_xdl_instance.hpp | 68 +++++++++++++++++-- ...xc_ndhwgk_bf16_default_pipev5_instance.cpp | 33 ++++++--- ...kzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp | 35 +++++++--- ...yxc_ndhwgk_f16_default_pipev5_instance.cpp | 34 +++++++--- ...gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp | 34 +++++++--- 6 files changed, 166 insertions(+), 40 deletions(-) diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 97852531a9..e20deb11ea 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -52,6 +52,8 @@ inline std::string get_device_name() } } +inline bool is_gfx90a() { return ck::get_device_name() == "gfx90a"; } + inline bool is_gfx12_supported() { return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201"; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp index 3a3dc156ec..c3834c7d17 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp @@ -77,6 +77,30 @@ using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances = std:: // clang-format on >; +// Problematic instance on gfx90a due to register splillage in block-GEMM v5 pipeline. +// Compiler doesn't handle correctly the register presure on gfx90a, which results in failing +// accuracy tests fail for 3D bwd weight conv. The problem occurs at least for compiler version +// 22.0.0git (https://github.com/ROCm/llvm-project.git +// 2de9eb6063dd56b109cf139a75550b7b06808273+PATCHED:9a6ac45c97a1e511db838c5b46257324d2de1780) +// Older compilers from the 20.0 family produce correct results. +template +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_high_reg_usage_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Dt, Dt, Dt, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion> + // clang-format on + >; + template -using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances = std::tuple< +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_base_instances = std::tuple< // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| @@ -95,12 +119,37 @@ using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances = std::tuple DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 4, 5, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 4, 7, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion> // clang-format on >; +template +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances = decltype(::std::tuple_cat( + ::std::declval< + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_base_instances>(), + ::std::declval>())); + template -using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances = std::tuple< +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_base_instances = std::tuple< // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| @@ -168,12 +217,23 @@ using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances = std::tupl DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 4, 5, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 4, 7, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion> //clang-format on >; +template +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances = + decltype(::std::tuple_cat( + ::std::declval>(), + ::std::declval>())); + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp index b9606a3e6c..1091825fd6 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp @@ -22,15 +22,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_ PassThrough, PassThrough>>>& instances) { - add_device_operation_instances(instances, - device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances< - 3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightDefault, - BlockGemmPipelineScheduler::Intrawave, - BlockGemmPipelineVersion::v5>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_base_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); + if(!is_gfx90a()) + { + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_high_reg_usage_instances< + 3, + ck::bhalf_t, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp index fc562203a0..93d84ede5e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp @@ -3,6 +3,7 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -22,15 +23,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pip PassThrough, PassThrough>>>& instances) { - add_device_operation_instances(instances, - device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances< - 3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightFilter1x1Stride1Pad0, - BlockGemmPipelineScheduler::Intrawave, - BlockGemmPipelineVersion::v5>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_base_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); + + if(!is_gfx90a()) + { + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_high_reg_usage_instances< + 3, + ck::bhalf_t, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp index 7294509406..d0cfe7ae98 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp @@ -3,6 +3,7 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -22,15 +23,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_p PassThrough, PassThrough>>>& instances) { - add_device_operation_instances(instances, - device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances< - 3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightDefault, - BlockGemmPipelineScheduler::Intrawave, - BlockGemmPipelineVersion::v5>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_base_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); + if(!is_gfx90a()) + { + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_high_reg_usage_instances< + 3, + ck::half_t, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp index c53347c293..98dd79e484 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp @@ -3,6 +3,7 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { @@ -22,15 +23,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipe PassThrough, PassThrough>>>& instances) { - add_device_operation_instances(instances, - device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances< - 3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightFilter1x1Stride1Pad0, - BlockGemmPipelineScheduler::Intrawave, - BlockGemmPipelineVersion::v5>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_base_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); + if(!is_gfx90a()) + { + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_high_reg_usage_instances< + 3, + ck::half_t, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); + } } } // namespace instance