From 513b14c5f27d5fb5c0644dac0e4f13fd74fb867c Mon Sep 17 00:00:00 2001 From: ApoorvaKalyani Date: Thu, 22 Jan 2026 09:53:59 +0100 Subject: [PATCH] Grouped conv_fwd_bias_bnorm_clamp instances and tests (#3525) * Added bias_bnorm_clamp instances. * fwd_bias_bnorm_clamp comp instances * fwd_bias_bnorm_mem_inter and mem_intra instances * fwd_bias_bnorm_merged_group_instances * fwd_bias_bnorm_clamp_conv3d_bf16 and f16 instances * Device level changes for fwd_bias_bnorm_clamp * Added the test to the regression test list. * Removed the part 2 and 2x instances * Removed the irrelevant checks in wmma * Refactored the instances to adapt to new device implementation * Updated the reference and include files * enabling tests * Added missing profiler * Added missing instance entry , deleted by mistake * Reduce bias bnorm clamp instances to only a single generic one. * Clean up cmakelists file * clang-format * Change bias bnorm clamp tests to use monotone initialization values to avoid tiny off-integer gemm results on RDNA3 from blowing up. * Renaming some instance lists and add functions to be more standardized. * Commented out non default instances. 
--------- Co-authored-by: kiefer [ROCm/composable_kernel commit: 8daf6ea3026aebe3481792c03026692631059725] --- ...uped_conv_fwd_wmma_cshufflev3_instance.hpp | 45 +++- ...d_convolution_forward_bias_bnorm_clamp.hpp | 57 +++++ ...rward_bias_bnorm_clamp_wmma_cshufflev3.inc | 78 +++++++ .../CMakeLists.txt | 32 +-- ...fflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 65 ++++++ ...ufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 65 ++++++ .../CMakeLists.txt | 78 +++---- ...fflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 65 ++++++ ...ufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 65 ++++++ ...grouped_conv_fwd_bias_bnorm_clamp_impl.hpp | 35 +-- profiler/src/CMakeLists.txt | 3 + ...file_grouped_conv_fwd_bias_bnorm_clamp.cpp | 202 ++++++++++++++++++ test/CMakeLists.txt | 1 + .../CMakeLists.txt | 15 +- ...st_grouped_convnd_fwd_bias_bnorm_clamp.cpp | 35 +-- ...grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp | 35 ++- 16 files changed, 768 insertions(+), 108 deletions(-) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp_wmma_cshufflev3.inc create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 profiler/src/profile_grouped_conv_fwd_bias_bnorm_clamp.cpp diff --git 
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp index 61b85dd12c..65ac3b7bc5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp @@ -24,9 +24,10 @@ using Empty_Tuple = ck::Tuple<>; using namespace ck::tensor_layout::convolution; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddClamp = ck::tensor_operation::element_wise::AddClamp; -using Clamp = ck::tensor_operation::element_wise::Clamp; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddClamp = ck::tensor_operation::element_wise::AddClamp; +using Clamp = ck::tensor_operation::element_wise::Clamp; +using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; @@ -40,6 +41,25 @@ static constexpr auto ConvFwdOddC = static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| 
CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + template ; +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline 
scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + template && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); + } +#endif + } + // layout NDHWGC/GKZYXC/NDHWGK + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v 
&& + is_same_v) + { + add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } +#endif + } +#endif // CK_USE_WMMA + return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp_wmma_cshufflev3.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp_wmma_cshufflev3.inc new file mode 100644 index 0000000000..e2ad6df07e --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp_wmma_cshufflev3.inc @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector< + std::unique_ptr, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector< + std::unique_ptr, + NHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances( 
+ std::vector, + NDHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt index cf1eaf0e12..d089663f37 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# XDL_AND_WMMA_KERNELS set(GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP) include(ShardInstantiation) @@ -69,15 +69,6 @@ generate_sharded_instantiations( OUTPUT_DIR ${GENERATED_DIR}/xdl ) -set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances - TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instance.in - NUM_SHARDS 3 - SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP - OUTPUT_DIR ${GENERATED_DIR}/xdl -) - # large tensor # NHWGC, GKYXC, NHWGK @@ -89,7 +80,6 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) - set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances @@ -108,6 +98,15 @@ generate_sharded_instantiations( OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME 
device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances + TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instance.in + NUM_SHARDS 3 + SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl +) + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances @@ -193,7 +192,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances @@ -325,4 +324,11 @@ generate_sharded_instantiations( OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) -add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance ${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP}) +#WMMA_Cshuffle_v3 +add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance + wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp + ${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP} +) + + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..4186771720 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,65 @@ 
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector< + std::unique_ptr, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances< + 2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); + + // Note: Commented out temporarily , might be used later. + + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances< + // 2, + // NHWGC, + // GKYXC, + // Tuple, + // NHWGK, + // ConvFwd1x1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); + + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances< + // 2, + // NHWGC, + // GKYXC, + // Tuple, + // NHWGK, + // ConvFwd1x1S1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000..6d69352e4c --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector< + std::unique_ptr, + NHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances< + 2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); + + // Note: Commented out temporarily , might be used later. 
+ + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances< + // 2, + // NHWGC, + // GKYXC, + // Tuple, + // NHWGK, + // ConvFwd1x1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); + + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances< + // 2, + // NHWGC, + // GKYXC, + // Tuple, + // NHWGK, + // ConvFwd1x1S1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt index 9796c561c0..dc759cbb54 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -1,8 +1,8 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS -set(GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP) +# XDL_AND_WMMA_KERNELS +set(GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP) include(ShardInstantiation) @@ -11,7 +11,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) @@ -20,7 +20,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) @@ -29,7 +29,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) @@ -38,7 +38,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) @@ -47,7 +47,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances TEMPLATE_FILE 
xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.in NUM_SHARDS 4 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) @@ -56,7 +56,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.in NUM_SHARDS 4 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) @@ -65,7 +65,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.in NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) @@ -74,7 +74,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instance.in NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) # large tensor @@ -85,16 +85,17 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME 
device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.in NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) @@ -103,7 +104,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.in NUM_SHARDS 2 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) @@ -112,7 +113,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in NUM_SHARDS 2 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) @@ -124,7 +125,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) @@ -133,7 +134,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.in 
NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) @@ -142,7 +143,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.in NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) @@ -151,7 +152,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) #mem @@ -162,16 +163,15 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.in NUM_SHARDS 20 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.in NUM_SHARDS 20 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) @@ -180,7 +180,7 @@ generate_sharded_instantiations( INSTANCES_NAME 
device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) # NDHWGC, GKZYXC, NDHWGK @@ -190,7 +190,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) @@ -199,7 +199,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.in NUM_SHARDS 20 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) @@ -208,7 +208,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.in NUM_SHARDS 20 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) @@ -217,7 +217,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST 
GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) @@ -226,7 +226,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) @@ -238,7 +238,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in NUM_SHARDS 11 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) @@ -247,7 +247,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in NUM_SHARDS 1 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) @@ -256,7 +256,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.in NUM_SHARDS 4 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) @@ -265,7 +265,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in 
NUM_SHARDS 4 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) @@ -274,7 +274,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instance.in NUM_SHARDS 1 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) @@ -283,7 +283,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instance.in NUM_SHARDS 1 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) @@ -292,7 +292,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instance.in NUM_SHARDS 5 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) @@ -301,8 +301,14 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instance.in NUM_SHARDS 12 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) -add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance ${GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP}) +#WMMA_Cshuffle_v3 
+add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance + wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp + ${GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP} +) + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..b67e6e7c7c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances< + 3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); + + // Note: Commented out temporarily, might be used later.
+ + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances< + // 3, + // NDHWGC, + // GKZYXC, + // Tuple, + // NDHWGK, + // ConvFwd1x1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); + + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances< + // 3, + // NDHWGC, + // GKZYXC, + // Tuple, + // NDHWGK, + // ConvFwd1x1S1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000..0bddf9b8f3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances< + 3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); + + // Note: Commented out temporarily, might be used later. + + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances< + // 3, + // NDHWGC, + // GKZYXC, + // Tuple, + // NDHWGK, + // ConvFwd1x1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); + + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances< + // 3, + // NDHWGC, + // GKZYXC, + // Tuple, + // NDHWGK, + // ConvFwd1x1S1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp index 22ff02676a..e47cc72b60 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp @@ -122,12 +122,12 @@ template -bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - const ck::utils::conv::ConvParam&
conv_param, - int instance_index = -1) +bool profile_grouped_conv_fwd_bias_bnorm_clamp_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + int instance_index = -1) { const float floor = 0.f; const float ceil = 2048.f; @@ -198,18 +198,29 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, std::cout << "scale: " << scale.mDesc << std::endl; std::cout << "shift: " << shift.mDesc << std::endl; + // Note: For the integer initialization method (which is used for verification in the tests), I + // changed the initialization ranges such that the overall operation becomes monotone. This + // means that all multiplications are positive, and all additions are positive. Without this, + // the outelementop can make small relative errors arbitrarily large by shifting them toward + // zero. In this specific case this should not be an issue, since small integer inputs should + // lead to exact outputs from the gemm. However, this is not the case on RDNA3, where integer + // inputs can lead to slightly off-integer outputs. This is another issue to investigate, but it + // remains the case that the outelementop blowing up tiny errors is not reasonable, so changing + // the operation to monotone for now. If we want to move away from monotone we would need to + // have a proper error propagation analysis, which is much more complicated. switch(init_method) { case 0: break; case 1: - input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - weight.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + input.GenerateTensorValue(GeneratorTensor_2{0, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{0, 5}); - bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - mean.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias.GenerateTensorValue(GeneratorTensor_2{0, 5}); + // Mean is negative because this is subtracted. 
+ mean.GenerateTensorValue(GeneratorTensor_2{-5, 0}); variance.GenerateTensorValue(GeneratorTensor_2{0, 5}); - scale.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - shift.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + scale.GenerateTensorValue(GeneratorTensor_2{0, 5}); + shift.GenerateTensorValue(GeneratorTensor_2{0, 5}); break; default: input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 3379fd15d1..012d6e1502 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -100,6 +100,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_bnorm_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp) @@ -240,6 +241,8 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_bnorm_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_bnorm_clamp_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance) diff --git a/profiler/src/profile_grouped_conv_fwd_bias_bnorm_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_bias_bnorm_clamp.cpp new file mode 100644 index 0000000000..179317bb28 --- /dev/null +++ 
b/profiler/src/profile_grouped_conv_fwd_bias_bnorm_clamp.cpp @@ -0,0 +1,202 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp" + +#include "ck/utility/data_type.hpp" +#include "ck/utility/ignore.hpp" +#include "profiler_operation_registry.hpp" + +#include + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK, // 0 + NHWGC_GKYXC_NHWGK, // 1 + NGCHW_GKYXC_NGKHW, // 2 + NGCHW_GKCYX_NGKHW, // 3 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 + F32_F32_F32_TF32, // 8 +}; + +enum struct IndexType +{ + INDEX_T, // 0 + LONG_INDEX_T, // 1 +}; + +#define OP_NAME "grouped_conv_fwd_bias_bnorm_clamp" +#define OP_DESC "Grouped Convolution Forward+Bias+Bnorm+Clamp" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight bf16, Output bf16\n" + << " 3: Input int8, Weight int8, Output int8\n" + << " 4: Input fp8, Weight fp8, Output fp8\n" + << " 5: Input bf8, Weight bf8, Output fp8\n" + << " 6: Input fp8, Weight bf8, Output fp8\n" + << " 7: Input bf8, Weight fp8, Output fp8\n" + << " 8: Input fp32, Weight fp32, Output fp32, Compute tf32)\n" + << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" + << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " + "G, K, Ho, Wo]\n" + << " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, " + "G, K, Ho, Wo])\n" + << "arg4: indexing data type (0: 32-bit, 1: 
64-bit)\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +int grouped_conv_fwd_bias_bnorm_clamp(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const auto index_type = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 9 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + if(index_type != IndexType::INDEX_T) + { + std::cout << "this indexing data type is not implemented" << std::endl; + return 1; + } + + using F32 = float; + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + using TF32 = ck::tf32_t; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto a_compute_type, + 
auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + bool pass = ck::profiler::profile_grouped_conv_fwd_bias_bnorm_clamp_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + + std::cout << "this data_type & layout is not 
implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_bias_bnorm_clamp); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index ef2ac098ac..b0b5f1c82f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -33,6 +33,7 @@ set(REGRESSION_TESTS test_convnd_fwd test_convnd_bwd_data test_grouped_convnd_fwd + test_grouped_convnd_fwd_bias_bnorm_clamp test_grouped_convnd_fwd_scaleadd_ab test_grouped_convnd_bwd_weight test_softmax_rank3 diff --git a/test/grouped_convnd_fwd_activation/CMakeLists.txt b/test/grouped_convnd_fwd_activation/CMakeLists.txt index e87ef77e6d..4808f82101 100644 --- a/test/grouped_convnd_fwd_activation/CMakeLists.txt +++ b/test/grouped_convnd_fwd_activation/CMakeLists.txt @@ -1,15 +1,6 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -if(GPU_TARGETS MATCHES "gfx9|gfx12") - #Fail on gfx11 CI but fail to reproduce it in local, disable it temporary - add_gtest_executable(test_grouped_convnd_fwd_bias_bnorm_clamp test_grouped_convnd_fwd_bias_bnorm_clamp.cpp) - target_link_libraries(test_grouped_convnd_fwd_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance) - - add_gtest_executable(test_grouped_convnd_fwd_gk_bias_bnorm_clamp test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp) - target_link_libraries(test_grouped_convnd_fwd_gk_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance) -endif() - if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_grouped_convnd_fwd_bias_clamp test_grouped_convnd_fwd_bias_clamp.cpp) target_link_libraries(test_grouped_convnd_fwd_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance) @@ -26,4 +17,10 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") 
add_gtest_executable(test_grouped_convnd_fwd_scale test_grouped_convnd_fwd_scale.cpp) target_link_libraries(test_grouped_convnd_fwd_scale PRIVATE utility device_grouped_conv3d_fwd_scale_instance) + + add_gtest_executable(test_grouped_convnd_fwd_bias_bnorm_clamp test_grouped_convnd_fwd_bias_bnorm_clamp.cpp) + target_link_libraries(test_grouped_convnd_fwd_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance) + + add_gtest_executable(test_grouped_convnd_fwd_gk_bias_bnorm_clamp test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp) + target_link_libraries(test_grouped_convnd_fwd_gk_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance) endif() diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_bnorm_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_bnorm_clamp.cpp index c54b218739..93007131ab 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_bnorm_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_bnorm_clamp.cpp @@ -39,23 +39,24 @@ class TestGroupedConvndFwd : public ::testing::Test continue; } auto& param = conv_params[i]; - pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl( - true, // do_verification - 1, // init_method: integer value - false, // do_log - false, // time_kernel - param, - instance_index); + pass = pass && + ck::profiler::profile_grouped_conv_fwd_bias_bnorm_clamp_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param, + instance_index); } EXPECT_TRUE(pass); } diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp index 8d0024354b..e17cd70d97 100644 --- 
a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp @@ -38,24 +38,23 @@ class TestGroupedConvndFwd : public ::testing::Test continue; } auto& param = conv_params[i]; - pass = pass && - ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl( - true, // do_verification - 1, // init_method: integer value - false, // do_log - false, // time_kernel - param, - instance_index); + pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_bnorm_clamp_impl< + NDimSpatial, + InLayout, + WeiLayout, + OutLayout, + DataType, + DataType, + DataType, + DataType, + DataType, + IndexType, + true /*ElementwiseGK*/>(true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param, + instance_index); } EXPECT_TRUE(pass); }