From ac28f1b016bb35b1d7f031b4a76d14a01b79ad5b Mon Sep 17 00:00:00 2001 From: Kiefer van Teutem <50830967+krithalith@users.noreply.github.com> Date: Mon, 29 Dec 2025 12:58:29 +0100 Subject: [PATCH] Replace grouped conv bwd wei wmmaV3 bilin/scale bf16f32bf16 support with bf16bf16bf16 (#3470) * Replace grouped convolution bwd weight wmma v3 bilinear and scale bf16f32bf16 support with bf16bf16bf16 support. Update tests. * Tentative fix for bwd weight bilinear bf16bf16bf16, seems like the bilinear elementwise overload for this case (bf16, f32 accu, bf16) was wrong. [ROCm/composable_kernel commit: 88ae4455806efe2019bb0403606f7c4a1e3d9c3a] --- ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 1 - .../element/binary_element_wise_operation.hpp | 4 +-- ...conv_bwd_weight_wmma_bilinear_instance.hpp | 32 +++++++++---------- ...ed_conv_bwd_weight_wmma_scale_instance.hpp | 24 +++++++------- ...d_convolution_backward_weight_bilinear.hpp | 11 ++++--- ...uped_convolution_backward_weight_scale.hpp | 9 +++--- ...ear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 6 ++-- ...ale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 4 +-- ...est_grouped_convnd_bwd_weight_bilinear.cpp | 1 + .../test_grouped_convnd_bwd_weight_scale.cpp | 1 + 10 files changed, 47 insertions(+), 46 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index 86e8defb83..ba540077ca 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -746,7 +746,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); AccDataType* p_e_grid = type_convert(arg.p_workspace_); - ; // Convolution kernel dispatch typename GridwiseGemm::Argument gemm_arg{ diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 5da2dbc567..ed95de3a8b 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -348,9 +348,7 @@ struct Bilinear __host__ __device__ constexpr void operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const { - const float x1_tmp = ck::type_convert(x1); - const float y_tmp = alpha_ * x0 + beta_ * x1_tmp; - y = y_tmp; + y = type_convert(alpha_ * x0 + beta_ * ck::type_convert(x1)); }; template <> diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp index f254628f73..85b4e9b056 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp @@ -70,24 +70,24 @@ template using device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_bilinear_instances = std::tuple< // clang-format off - //#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| - //#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline | - //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | - //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + //#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| + //#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline | + //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | // generic instance - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, BF16, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // other instances - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> - // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Verification failure - // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Verification failure + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, BF16, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, BF16, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, BF16, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, BF16, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, BF16, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, BF16, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, BF16, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, BF16, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, BF16, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, BF16, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Verification failure + // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, BF16, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Verification failure // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp index e893c92d1d..5f27b14450 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp @@ -74,19 +74,19 @@ using device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_scale_instances = std:: //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | // generic instance - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, BF16, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // other instances - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> - // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Verification failure - // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Verification failure + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, BF16, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, BF16, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, BF16, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, BF16, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, BF16, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, BF16, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, BF16, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, BF16, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, BF16, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, BF16, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Verification failure + // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, BF16, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Verification failure // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp index 48a43e59ad..6fde57b44e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp @@ -34,16 +34,16 @@ void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16 PassThrough>>>& instances); #endif #ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( +void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_bf16_bf16_instances( std::vector, BF16, - F32, BF16, - Tuple, + BF16, + Tuple, PassThrough, Bilinear, PassThrough>>>& instances); @@ -197,12 +197,13 @@ struct DeviceOperationInstanceFactory< } #endif #ifdef CK_ENABLE_BF16 - if constexpr(is_same_v && is_same_v && + if constexpr(is_same_v && + is_same_v && is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_bf16_bf16_instances( op_ptrs); } #endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp index e6a64e3716..c24243943a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp @@ -35,14 +35,14 @@ void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_in #endif #ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( +void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_bf16_bf16_instances( std::vector, BF16, - F32, + BF16, BF16, Tuple<>, PassThrough, @@ -197,12 +197,13 @@ struct DeviceOperationInstanceFactory< } #endif #ifdef CK_ENABLE_BF16 - if constexpr(is_same_v && is_same_v && + if constexpr(is_same_v && + is_same_v && is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_bf16_bf16_instances( op_ptrs); } #endif diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index 06398729af..f13bf5c79c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -10,16 +10,16 @@ namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( +void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_bf16_bf16_instances( std::vector, BF16, - F32, BF16, - Tuple, + BF16, + Tuple, PassThrough, Bilinear, PassThrough>>>& instances) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index 32aeb2f19f..e53f0b412e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -10,14 +10,14 @@ namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( +void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_bf16_bf16_instances( std::vector, BF16, - F32, + BF16, BF16, Tuple<>, PassThrough, diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp index 08f509a7e5..ff025e2dba 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp @@ -296,6 +296,7 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight using KernelTypes3d = ::testing::Types>, std::tuple>, + std::tuple>, std::tuple>>; TYPED_TEST_SUITE(TestGroupedConvndBwdWeight3d, KernelTypes3d); diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp index 5600ab5c0a..dba2fbd5d4 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp @@ -269,6 +269,7 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight using KernelTypes3d = ::testing::Types>, std::tuple>, + std::tuple>, std::tuple>>; TYPED_TEST_SUITE(TestGroupedConvndBwdWeight3d, KernelTypes3d);