From 5d4107862b1f8686950d4e4ea0bec0f993a47e72 Mon Sep 17 00:00:00 2001 From: JP-Fernando <103817231+JP-Fernando@users.noreply.github.com> Date: Wed, 11 Mar 2026 13:05:44 +0100 Subject: [PATCH] [CK] Add BF16^3 support to grouped conv bwd weight: bilinear and scale (#4591) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Until now, XDL grouped conv bwd weight for bilinear and scale only supported bf16f32bf16. Therefore, bf16bf16bf16 support should be added. ## Technical Details Instances were added to the relevant files in `library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/` folder. In addition, `add()` functions were included in new files in `library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/` and `library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/` folders. The new .cpp files were also included in the `CMakeFiles.txt` files of both folders. ## Test Plan Execute `grouped_convnd_bwd_weight` tests to check execution on different architectures. The tests for bilinear and scale already include the tuple `std::tuple>`, so in principle, there is nothing to modify in the tests themselves. ## Test Result `gfx1201`: Tests passed. `gfx1100`: Tests passed. `gfx90a`: Tests passed. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Fernando Jiménez --- ..._conv_bwd_weight_xdl_bilinear_instance.hpp | 18 ++++++- ...ped_conv_bwd_weight_xdl_scale_instance.hpp | 18 ++++++- ...d_convolution_backward_weight_bilinear.hpp | 24 +++++++++ ...uped_convolution_backward_weight_scale.hpp | 23 +++++++++ .../CMakeLists.txt | 4 +- ...c_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp | 50 ++++++++++++++++++ ...ear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 18 ++----- .../CMakeLists.txt | 4 +- ...c_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp | 51 +++++++++++++++++++ ...ale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 15 ++---- 10 files changed, 195 insertions(+), 30 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp index 75655dda8d..87b47e54f1 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp @@ -149,7 +149,7 @@ template -using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_bilinear_instances = std::tuple< +using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_bilinear_instances = std::tuple< // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| DsData| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| //#########################################| Dim| | | | Layout| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| @@ -177,6 +177,22 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_bilinear_instances = std // clang-format on >; +template +using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_bilinear_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| DsData| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| | | | Layout| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + // generic instance + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, BF16, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + template -using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_scale_instances = std::tuple< +using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_scale_instances = std::tuple< // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| DsData| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| //#########################################| Dim| | | | Layout| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| @@ -176,6 +176,22 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_scale_instances = std::t // clang-format on >; +template +using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_scale_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| DsData| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| | | | Layout| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + // generic instance + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, BF16, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + template >>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + BF16, + BF16, + BF16, + Tuple, + PassThrough, + Bilinear, + PassThrough>>>& instances); #endif #ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances( @@ -255,6 +269,16 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( op_ptrs); } + + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + } #endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 if constexpr(is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp index c24243943a..0ad9c38714 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp @@ -66,6 +66,19 @@ void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f3 PassThrough, Scale, PassThrough>>>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + BF16, + BF16, + BF16, + Tuple<>, + PassThrough, + Scale, + PassThrough>>>& instances); #endif #ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances( @@ -255,6 +268,16 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( op_ptrs); } + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + } + #endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 if constexpr(is_same_v && is_same_v && diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt index 455f14d9f4..9cb0ccda0d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt @@ -6,7 +6,9 @@ set(GROUPED_CONV3D_BWD_WEIGHT_BILINEAR xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) + xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp + ) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_BWD_WEIGHT_BILINEAR diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp new file mode 100644 index 0000000000..85d972d64a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp @@ -0,0 +1,50 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector, + BF16, + F32, + BF16, + Tuple, + PassThrough, + Bilinear, + PassThrough>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_bilinear_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_bilinear_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index f0582697b9..1b019a8cf7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -9,22 +9,21 @@ namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( +void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector, BF16, - F32, BF16, - Tuple, + BF16, + Tuple, PassThrough, Bilinear, PassThrough>>>& instances) { - // 1. Default + // Default bwd weight bilinear add_device_operation_instances( instances, device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_bilinear_instances< @@ -33,15 +32,6 @@ void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16 GKZYXC, NDHWGK, ConvBwdWeightDefault>{}); - // 2. Filter1x1Stride1Pad0 - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_bilinear_instances< - 3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt index b7fefdc94f..85c99dc422 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt @@ -6,7 +6,9 @@ set(GROUPED_CONV3D_BWD_WEIGHT_SCALE xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) + xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp + ) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_BWD_WEIGHT_SCALE diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp new file mode 100644 index 0000000000..41f051479b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector, + BF16, + F32, + BF16, + Tuple<>, + PassThrough, + Scale, + PassThrough>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_scale_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_scale_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index c1bc8c3a1c..a08c4f4054 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -10,21 +10,21 @@ namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( +void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector, BF16, - F32, + BF16, BF16, Tuple<>, PassThrough, Scale, PassThrough>>>& instances) { - // 1. Default + // Default conv bwd weight add_device_operation_instances( instances, device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_scale_instances<3, @@ -32,15 +32,6 @@ void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f3 GKZYXC, NDHWGK, ConvBwdWeightDefault>{}); - // 2. Filter1x1Stride1Pad0 - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_scale_instances< - 3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance