Grouped conv_fwd_bias_bnorm_clamp instances and tests (#3525)

* Added bias_bnorm_clamp instances.

* fwd_bias_bnorm_clamp comp instances

* fwd_bias_bnorm_mem_inter and mem_intra instances

* fwd_bias_bnorm_merged_group_instances

* fwd_bias_bnorm_clamp_conv3d_bf16 and f16 instances

* Device level changes for fwd_bias_bnorm_clamp

* Added the test to the regression test list.

* Removed the part 2 and 2x instances

* Removed the irrelevant checks in wmma

* Refactored the instances to adapt to new device implementation

* Updated the reference and include files

* enabling tests

* Added missing profiler

* Added missing instance entry , deleted by mistake

* Reduce bias bnorm clamp instances to only a single generic one.

* Clean up cmakelists file

* clang-format

* Change bias bnorm clamp tests to use monotone initialization values to avoid tiny off-integer gemm results on RDNA3 from blowing up.

* Renaming some instance lists and add functions to be more standardized.

* Commented out non default instances.

---------

Co-authored-by: kiefer <kiefer.van.teutem@streamhpc.com>

[ROCm/composable_kernel commit: 8daf6ea302]
This commit is contained in:
ApoorvaKalyani
2026-01-22 09:53:59 +01:00
committed by GitHub
parent f6fac4cea6
commit ec0f5c82ca
16 changed files with 768 additions and 108 deletions

View File

@@ -1,7 +1,7 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# XDL_AND_WMMA_KERNELS
set(GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP)
include(ShardInstantiation)
@@ -69,15 +69,6 @@ generate_sharded_instantiations(
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
# large tensor
# NHWGC, GKYXC, NHWGK
@@ -89,7 +80,6 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances
@@ -108,6 +98,15 @@ generate_sharded_instantiations(
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances
@@ -193,7 +192,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
@@ -325,4 +324,11 @@ generate_sharded_instantiations(
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance ${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP})
#WMMA_Cshuffle_v3
add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance
wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP}
)

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BF16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BiasNormalizeInInferClamp>{});
// Note: Commented out temporarily , might be used later.
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
// 2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
// NHWGK,
// ConvFwd1x1P0,
// Tuple<BF16, BF16, BF16, BF16, BF16>,
// BiasNormalizeInInferClamp>{});
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
// 2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
// NHWGK,
// ConvFwd1x1S1P0,
// Tuple<BF16, BF16, BF16, BF16, BF16>,
// BiasNormalizeInInferClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16, F16, F16, F16, F16>,
F16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<F16, F16, F16, F16, F16>,
BiasNormalizeInInferClamp>{});
// Note: Commented out temporarily , might be used later.
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
// 2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
// NHWGK,
// ConvFwd1x1P0,
// Tuple<F16, F16, F16, F16, F16>,
// BiasNormalizeInInferClamp>{});
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
// 2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
// NHWGK,
// ConvFwd1x1S1P0,
// Tuple<F16, F16, F16, F16, F16>,
// BiasNormalizeInInferClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,8 +1,8 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
set(GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP)
# XDL_AND_WMMA_KERNELS
set(GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP)
include(ShardInstantiation)
@@ -11,7 +11,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -20,7 +20,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -29,7 +29,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -38,7 +38,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -47,7 +47,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.in
NUM_SHARDS 4
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -56,7 +56,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.in
NUM_SHARDS 4
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -65,7 +65,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -74,7 +74,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
# large tensor
@@ -85,16 +85,17 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
@@ -103,7 +104,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.in
NUM_SHARDS 2
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
@@ -112,7 +113,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
NUM_SHARDS 2
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
@@ -124,7 +125,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
)
@@ -133,7 +134,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
)
@@ -142,7 +143,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
)
@@ -151,7 +152,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
)
#mem
@@ -162,16 +163,15 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.in
NUM_SHARDS 20
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.in
NUM_SHARDS 20
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -180,7 +180,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
# NDHWGC, GKZYXC, NDHWGK
@@ -190,7 +190,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -199,7 +199,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.in
NUM_SHARDS 20
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -208,7 +208,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.in
NUM_SHARDS 20
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -217,7 +217,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -226,7 +226,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -238,7 +238,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in
NUM_SHARDS 11
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -247,7 +247,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in
NUM_SHARDS 1
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -256,7 +256,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.in
NUM_SHARDS 4
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -265,7 +265,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in
NUM_SHARDS 4
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -274,7 +274,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instance.in
NUM_SHARDS 1
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -283,7 +283,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instance.in
NUM_SHARDS 1
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -292,7 +292,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instance.in
NUM_SHARDS 5
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -301,8 +301,14 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instance.in
NUM_SHARDS 12
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance ${GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP})
#WMMA_Cshuffle_v3
add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance
wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
${GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP}
)

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
BF16,
BF16,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BF16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BiasNormalizeInInferClamp>{});
// Note: Commented out temporarily , might be used later.
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
// 3,
// NDHWGC,
// GKZYXC,
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
// NDHWGK,
// ConvFwd1x1P0,
// Tuple<BF16, BF16, BF16, BF16, BF16>,
// BiasNormalizeInInferClamp>{});
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
// 3,
// NDHWGC,
// GKZYXC,
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
// NDHWGK,
// ConvFwd1x1S1P0,
// Tuple<BF16, BF16, BF16, BF16, BF16>,
// BiasNormalizeInInferClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16, F16, F16, F16, F16>,
F16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<F16, F16, F16, F16, F16>,
BiasNormalizeInInferClamp>{});
// Note: Commented out temporarily , might be used later.
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
// 3,
// NDHWGC,
// GKZYXC,
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
// NDHWGK,
// ConvFwd1x1P0,
// Tuple<F16, F16, F16, F16, F16>,
// BiasNormalizeInInferClamp>{});
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
// 3,
// NDHWGC,
// GKZYXC,
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
// NDHWGK,
// ConvFwd1x1S1P0,
// Tuple<F16, F16, F16, F16, F16>,
// BiasNormalizeInInferClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck