mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
conv:tf32:add missed instances (#3081)
* conv:tf32:add missed instances
[ROCm/composable_kernel commit: 6bbc05e1bd]
This commit is contained in:
@@ -13,6 +13,7 @@ set(GROUPED_CONV2D_FWD
|
||||
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp
|
||||
# NGCHW, GKYXC, NGKHW
|
||||
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdDefault>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -2,7 +2,7 @@
|
||||
set(GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP)
|
||||
include(ShardInstantiation)
|
||||
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances
|
||||
@@ -11,7 +11,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instances
|
||||
@@ -20,7 +20,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances
|
||||
@@ -29,7 +29,16 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances
|
||||
@@ -38,7 +47,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances
|
||||
@@ -47,7 +56,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances
|
||||
@@ -56,9 +65,19 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
# large tensor
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances
|
||||
@@ -67,7 +86,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances
|
||||
@@ -76,7 +95,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances
|
||||
@@ -97,7 +116,7 @@ generate_sharded_instantiations(
|
||||
|
||||
# merged groups
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances
|
||||
@@ -106,7 +125,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances
|
||||
@@ -115,7 +134,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances
|
||||
@@ -135,7 +154,7 @@ generate_sharded_instantiations(
|
||||
)
|
||||
#mem
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances
|
||||
@@ -144,7 +163,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances
|
||||
@@ -153,7 +172,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
|
||||
@@ -173,7 +192,7 @@ generate_sharded_instantiations(
|
||||
)
|
||||
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances
|
||||
@@ -182,7 +201,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances
|
||||
@@ -191,7 +210,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
|
||||
@@ -212,7 +231,7 @@ generate_sharded_instantiations(
|
||||
|
||||
#comp
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances
|
||||
@@ -221,7 +240,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances
|
||||
@@ -230,7 +249,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances
|
||||
@@ -257,7 +276,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances
|
||||
@@ -266,7 +285,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances
|
||||
@@ -275,7 +294,7 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances =
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp,
|
||||
TF32,
|
||||
TF32>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances_shard(
|
||||
device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,81 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances =
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp,
|
||||
TF32,
|
||||
TF32>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances_shard(
|
||||
device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -23,6 +23,7 @@ add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance
|
||||
xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_16x16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_16x16_instance.cpp
|
||||
xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_instance.cpp
|
||||
xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_instance.cpp
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<F32>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -23,6 +23,7 @@ add_instance_library(device_grouped_conv2d_fwd_clamp_instance
|
||||
xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_16x16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_16x16_instance.cpp
|
||||
xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_instance.cpp
|
||||
xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_instance.cpp
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -2,6 +2,7 @@
|
||||
set(GROUPED_CONV3D_BWD_DATA_BILINEAR
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp)
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_bwd_data_bilinear_instance ${GROUPED_CONV3D_BWD_DATA_BILINEAR})
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGC>,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_bilinear_f32_tf32_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGC>,
|
||||
NDHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_data_xdl_bilinear_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGC>,
|
||||
NDHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -2,6 +2,7 @@
|
||||
set(GROUPED_CONV3D_BWD_DATA_BILINEAR
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp)
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_bwd_data_scale_instance ${GROUPED_CONV3D_BWD_DATA_BILINEAR})
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Scale,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_scale_f32_tf32_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_bwd_data_xdl_scale_f32_tf32_instances<
|
||||
3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -11,6 +11,7 @@ set(GROUPED_CONV3D_FWD
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instance.cpp
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
TF32,
|
||||
TF32>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwdDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -65,6 +65,15 @@ generate_sharded_instantiations(
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
# large tensor
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances =
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp,
|
||||
TF32,
|
||||
TF32>>>;
|
||||
|
||||
// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] =
// out[n, do, ho, wo, g, k]  (the original comment described the 2-D layout;
// this is the 3-D NDHWGC/GKZYXC/NDHWGK case).
//
// Appends one shard of the TF32 16x16 fused bias+batchnorm+clamp forward-conv
// instances to `instances`. filter_tuple_by_modulo_t selects the subset of the
// full instance tuple whose index is congruent to ShardIndex modulo Shards, so
// the Shards translation units together register every instance exactly once.
template <int Shards, int ShardIndex>
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances_shard(
    device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances&
        instances)
{
    // Generic convolution specialization (no shape restrictions).
    add_device_operation_instances(
        instances,
        ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<
                                               3,
                                               NDHWGC,
                                               GKZYXC,
                                               Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
                                               NDHWGK,
                                               ConvFwdDefault,
                                               Tuple<F32, F32, F32, F32, F32>,
                                               BiasNormalizeInInferClamp>,
                                           Shards,
                                           ShardIndex>{});

    // Specialization for 1x1 filters with zero padding.
    add_device_operation_instances(
        instances,
        ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<
                                               3,
                                               NDHWGC,
                                               GKZYXC,
                                               Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
                                               NDHWGK,
                                               ConvFwd1x1P0,
                                               Tuple<F32, F32, F32, F32, F32>,
                                               BiasNormalizeInInferClamp>,
                                           Shards,
                                           ShardIndex>{});

    // Specialization for 1x1 filters, stride 1, zero padding.
    add_device_operation_instances(
        instances,
        ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<
                                               3,
                                               NDHWGC,
                                               GKZYXC,
                                               Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
                                               NDHWGK,
                                               ConvFwd1x1S1P0,
                                               Tuple<F32, F32, F32, F32, F32>,
                                               BiasNormalizeInInferClamp>,
                                           Shards,
                                           ShardIndex>{});
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -18,6 +18,7 @@ set(GROUPED_CONV3D_FWD
|
||||
|
||||
xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_16x16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_16x16_instance.cpp
|
||||
xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp
|
||||
xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] =
// out[n, do, ho, wo, g, k].
//
// Registers the TF32-compute 16x16 XDL instances for 3-D grouped forward
// convolution with a fused AddClamp epilogue (the single NDHWGK/F32 D-tensor
// is presumably the bias added before clamping — confirm against AddClamp).
// One registration call per conv specialization: default, 1x1 pad-0, and
// 1x1 stride-1 pad-0.
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                NDHWGC,
                                                                GKZYXC,
                                                                Tuple<NDHWGK>,
                                                                NDHWGK,
                                                                F32,
                                                                F32,
                                                                Tuple<F32>,
                                                                F32,
                                                                PassThrough,
                                                                PassThrough,
                                                                AddClamp,
                                                                TF32,
                                                                TF32>>>& instances)
{
    // Generic convolution specialization (no shape restrictions).
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<3,
                                                             NDHWGC,
                                                             GKZYXC,
                                                             Tuple<NDHWGK>,
                                                             NDHWGK,
                                                             ConvFwdDefault,
                                                             Tuple<F32>,
                                                             AddClamp>{});
    // Specialization for 1x1 filters with zero padding.
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<3,
                                                             NDHWGC,
                                                             GKZYXC,
                                                             Tuple<NDHWGK>,
                                                             NDHWGK,
                                                             ConvFwd1x1P0,
                                                             Tuple<F32>,
                                                             AddClamp>{});
    // Specialization for 1x1 filters, stride 1, zero padding.
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<3,
                                                             NDHWGC,
                                                             GKZYXC,
                                                             Tuple<NDHWGK>,
                                                             NDHWGK,
                                                             ConvFwd1x1S1P0,
                                                             Tuple<F32>,
                                                             AddClamp>{});
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -18,6 +18,7 @@ set(GROUPED_CONV3D_FWD
|
||||
|
||||
xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_16x16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_16x16_instance.cpp
|
||||
xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp
|
||||
xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] =
// out[n, do, ho, wo, g, k].
//
// Registers the TF32-compute 16x16 XDL instances for 3-D grouped forward
// convolution with a fused Clamp epilogue and no auxiliary D tensors
// (empty Tuple<>). One registration call per conv specialization:
// default, 1x1 pad-0, and 1x1 stride-1 pad-0.
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                NDHWGC,
                                                                GKZYXC,
                                                                Tuple<>,
                                                                NDHWGK,
                                                                F32,
                                                                F32,
                                                                Tuple<>,
                                                                F32,
                                                                PassThrough,
                                                                PassThrough,
                                                                Clamp,
                                                                TF32,
                                                                TF32>>>& instances)
{
    // Generic convolution specialization (no shape restrictions).
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<3,
                                                             NDHWGC,
                                                             GKZYXC,
                                                             Tuple<>,
                                                             NDHWGK,
                                                             ConvFwdDefault,
                                                             Tuple<>,
                                                             Clamp>{});
    // Specialization for 1x1 filters with zero padding.
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<3,
                                                             NDHWGC,
                                                             GKZYXC,
                                                             Tuple<>,
                                                             NDHWGK,
                                                             ConvFwd1x1P0,
                                                             Tuple<>,
                                                             Clamp>{});
    // Specialization for 1x1 filters, stride 1, zero padding.
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_f32_tf32_16x16_instances<3,
                                                             NDHWGC,
                                                             GKZYXC,
                                                             Tuple<>,
                                                             NDHWGK,
                                                             ConvFwd1x1S1P0,
                                                             Tuple<>,
                                                             Clamp>{});
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user