mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
Grouped conv_fwd_bias_bnorm_clamp instances and tests (#3525)
* Added bias_bnorm_clamp instances.
* fwd_bias_bnorm_clamp comp instances
* fwd_bias_bnorm_mem_inter and mem_intra instances
* fwd_bias_bnorm_merged_group_instances
* fwd_bias_bnorm_clamp_conv3d_bf16 and f16 instances
* Device level changes for fwd_bias_bnorm_clamp
* Added the test to the regression test list.
* Removed the part 2 and 2x instances
* Removed the irrelevant checks in wmma
* Refactored the instances to adapt to new device implementation
* Updated the reference and include files
* enabling tests
* Added missing profiler
* Added missing instance entry , deleted by mistake
* Reduce bias bnorm clamp instances to only a single generic one.
* Clean up cmakelists file
* clang-format
* Change bias bnorm clamp tests to use monotone initialization values to avoid tiny off-integer gemm results on RDNA3 from blowing up.
* Renaming some instance lists and add functions to be more standardized.
* Commented out non default instances.
---------
Co-authored-by: kiefer <kiefer.van.teutem@streamhpc.com>
[ROCm/composable_kernel commit: 8daf6ea302]
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# XDL_AND_WMMA_KERNELS
|
||||
set(GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP)
|
||||
include(ShardInstantiation)
|
||||
|
||||
@@ -69,15 +69,6 @@ generate_sharded_instantiations(
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
# large tensor
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
|
||||
@@ -89,7 +80,6 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances
|
||||
@@ -108,6 +98,15 @@ generate_sharded_instantiations(
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances
|
||||
@@ -193,7 +192,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
|
||||
@@ -325,4 +324,11 @@ generate_sharded_instantiations(
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance ${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP})
|
||||
#WMMA_Cshuffle_v3
|
||||
add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance
|
||||
wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP}
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BiasNormalizeInInferClamp>{});
|
||||
|
||||
// Note: Commented out temporarily , might be used later.
|
||||
|
||||
// add_device_operation_instances(instances,
|
||||
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
|
||||
// 2,
|
||||
// NHWGC,
|
||||
// GKYXC,
|
||||
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
// NHWGK,
|
||||
// ConvFwd1x1P0,
|
||||
// Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
// BiasNormalizeInInferClamp>{});
|
||||
|
||||
// add_device_operation_instances(instances,
|
||||
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
|
||||
// 2,
|
||||
// NHWGC,
|
||||
// GKYXC,
|
||||
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
// NHWGK,
|
||||
// ConvFwd1x1S1P0,
|
||||
// Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
// BiasNormalizeInInferClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,65 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
BiasNormalizeInInferClamp>{});
|
||||
|
||||
// Note: Commented out temporarily , might be used later.
|
||||
|
||||
// add_device_operation_instances(instances,
|
||||
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
|
||||
// 2,
|
||||
// NHWGC,
|
||||
// GKYXC,
|
||||
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
// NHWGK,
|
||||
// ConvFwd1x1P0,
|
||||
// Tuple<F16, F16, F16, F16, F16>,
|
||||
// BiasNormalizeInInferClamp>{});
|
||||
|
||||
// add_device_operation_instances(instances,
|
||||
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
|
||||
// 2,
|
||||
// NHWGC,
|
||||
// GKYXC,
|
||||
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
// NHWGK,
|
||||
// ConvFwd1x1S1P0,
|
||||
// Tuple<F16, F16, F16, F16, F16>,
|
||||
// BiasNormalizeInInferClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
set(GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP)
|
||||
# XDL_AND_WMMA_KERNELS
|
||||
set(GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP)
|
||||
include(ShardInstantiation)
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
@@ -20,7 +20,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
@@ -29,7 +29,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
@@ -38,7 +38,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
@@ -47,7 +47,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.in
|
||||
NUM_SHARDS 4
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
@@ -56,7 +56,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.in
|
||||
NUM_SHARDS 4
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
@@ -65,7 +65,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
@@ -74,7 +74,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
# large tensor
|
||||
@@ -85,16 +85,17 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances
|
||||
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
@@ -103,7 +104,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances
|
||||
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.in
|
||||
NUM_SHARDS 2
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
@@ -112,7 +113,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
|
||||
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
|
||||
NUM_SHARDS 2
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
@@ -124,7 +125,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances
|
||||
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
|
||||
@@ -133,7 +134,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
|
||||
@@ -142,7 +143,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances
|
||||
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
|
||||
@@ -151,7 +152,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
|
||||
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
#mem
|
||||
@@ -162,16 +163,15 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.in
|
||||
NUM_SHARDS 20
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.in
|
||||
NUM_SHARDS 20
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
@@ -180,7 +180,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
@@ -190,7 +190,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
@@ -199,7 +199,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.in
|
||||
NUM_SHARDS 20
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
@@ -208,7 +208,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.in
|
||||
NUM_SHARDS 20
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
@@ -217,7 +217,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
@@ -226,7 +226,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
@@ -238,7 +238,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in
|
||||
NUM_SHARDS 11
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
@@ -247,7 +247,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in
|
||||
NUM_SHARDS 1
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
@@ -256,7 +256,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.in
|
||||
NUM_SHARDS 4
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
@@ -265,7 +265,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in
|
||||
NUM_SHARDS 4
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
@@ -274,7 +274,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instance.in
|
||||
NUM_SHARDS 1
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
@@ -283,7 +283,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instance.in
|
||||
NUM_SHARDS 1
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
@@ -292,7 +292,7 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instance.in
|
||||
NUM_SHARDS 5
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
@@ -301,8 +301,14 @@ generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instance.in
|
||||
NUM_SHARDS 12
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance ${GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP})
|
||||
#WMMA_Cshuffle_v3
|
||||
add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance
|
||||
wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
${GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP}
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BiasNormalizeInInferClamp>{});
|
||||
|
||||
// Note: Commented out temporarily , might be used later.
|
||||
|
||||
// add_device_operation_instances(instances,
|
||||
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
|
||||
// 3,
|
||||
// NDHWGC,
|
||||
// GKZYXC,
|
||||
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
// NDHWGK,
|
||||
// ConvFwd1x1P0,
|
||||
// Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
// BiasNormalizeInInferClamp>{});
|
||||
|
||||
// add_device_operation_instances(instances,
|
||||
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
|
||||
// 3,
|
||||
// NDHWGC,
|
||||
// GKZYXC,
|
||||
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
// NDHWGK,
|
||||
// ConvFwd1x1S1P0,
|
||||
// Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
// BiasNormalizeInInferClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,65 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
BiasNormalizeInInferClamp>{});
|
||||
|
||||
// Note: Commented out temporarily , might be used later.
|
||||
|
||||
// add_device_operation_instances(instances,
|
||||
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
|
||||
// 3,
|
||||
// NDHWGC,
|
||||
// GKZYXC,
|
||||
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
// NDHWGK,
|
||||
// ConvFwd1x1P0,
|
||||
// Tuple<F16, F16, F16, F16, F16>,
|
||||
// BiasNormalizeInInferClamp>{});
|
||||
|
||||
// add_device_operation_instances(instances,
|
||||
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
|
||||
// 3,
|
||||
// NDHWGC,
|
||||
// GKZYXC,
|
||||
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
// NDHWGK,
|
||||
// ConvFwd1x1S1P0,
|
||||
// Tuple<F16, F16, F16, F16, F16>,
|
||||
// BiasNormalizeInInferClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user