mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
* Revert "Revert "feature:tf32:add initial conv3d fwd kernel support (#2763)" (#2848)"
This reverts commit 82da15ffa430a297fb072d0a15b3ada5753f69b1.
* fix compile error on gf12x
* only run tf32 example on gfx942
* only build tf32 instance on gfx942
* ckProfiler:only support tf32 in gfx942
* delete unuseful messages
[ROCm/composable_kernel commit: dd7af118d7]
This commit is contained in:
@@ -94,6 +94,11 @@ function(add_instance_library INSTANCE_NAME)
|
||||
message(DEBUG "removing gemm_universal_preshuffle_f8 instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
# Only build tf32 instances for gfx942
|
||||
if(NOT INST_TARGETS MATCHES "gfx942" AND source_name MATCHES "_tf32_")
|
||||
message(DEBUG "removing tf32 instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
|
||||
endforeach()
|
||||
|
||||
@@ -441,7 +446,7 @@ if(CK_DEVICE_MHA_INSTANCES AND NOT MIOPEN_REQ_LIBS_ONLY AND BUILD_MHA_LIB)
|
||||
add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
|
||||
target_compile_features(device_mha_operations PUBLIC)
|
||||
set_target_properties(device_mha_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
|
||||
rocm_install(TARGETS device_mha_operations
|
||||
EXPORT device_mha_operationsTargets)
|
||||
rocm_install(EXPORT device_mha_operationsTargets
|
||||
|
||||
@@ -7,6 +7,7 @@ set(GROUPED_CONV3D_FWD
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp
|
||||
@@ -34,7 +35,7 @@ set(GROUPED_CONV3D_FWD
|
||||
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp
|
||||
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp
|
||||
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp
|
||||
|
||||
|
||||
xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -2,7 +2,7 @@
|
||||
set(GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP)
|
||||
include(ShardInstantiation)
|
||||
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances
|
||||
@@ -11,7 +11,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
@@ -20,7 +20,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
|
||||
@@ -29,7 +29,16 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances
|
||||
@@ -38,7 +47,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances
|
||||
@@ -47,7 +56,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances
|
||||
@@ -58,7 +67,7 @@ generate_sharded_instantiations(
|
||||
)
|
||||
# large tensor
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances
|
||||
@@ -67,7 +76,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
@@ -76,7 +85,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances
|
||||
@@ -87,7 +96,7 @@ generate_sharded_instantiations(
|
||||
)
|
||||
# merged groups
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances
|
||||
@@ -96,7 +105,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
@@ -105,7 +114,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances
|
||||
@@ -116,7 +125,7 @@ generate_sharded_instantiations(
|
||||
)
|
||||
#mem
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances
|
||||
@@ -125,7 +134,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances
|
||||
@@ -134,7 +143,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances
|
||||
@@ -144,7 +153,7 @@ generate_sharded_instantiations(
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances
|
||||
@@ -153,7 +162,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances
|
||||
@@ -162,7 +171,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances
|
||||
@@ -173,7 +182,7 @@ generate_sharded_instantiations(
|
||||
)
|
||||
#comp
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances
|
||||
@@ -182,7 +191,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances
|
||||
@@ -191,7 +200,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances
|
||||
@@ -200,7 +209,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances
|
||||
@@ -209,7 +218,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances
|
||||
@@ -218,7 +227,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances
|
||||
@@ -227,7 +236,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances =
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp,
|
||||
TF32,
|
||||
TF32>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances_shard(
|
||||
device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -23,6 +23,8 @@ set(GROUPED_CONV3D_FWD
|
||||
xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp
|
||||
xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp
|
||||
)
|
||||
|
||||
xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_bias_clamp_instance ${GROUPED_CONV3D_FWD})
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -23,6 +23,8 @@ set(GROUPED_CONV3D_FWD
|
||||
xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp
|
||||
xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp
|
||||
)
|
||||
|
||||
xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_clamp_instance ${GROUPED_CONV3D_FWD})
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user