Grouped conv_fwd_bias_bnorm_clamp instances and tests (#3525)

* Added bias_bnorm_clamp instances.

* fwd_bias_bnorm_clamp comp instances

* fwd_bias_bnorm_mem_inter and mem_intra instances

* fwd_bias_bnorm_merged_group_instances

* fwd_bias_bnorm_clamp_conv3d_bf16 and f16 instances

* Device level changes for fwd_bias_bnorm_clamp

* Added the test to the regression test list.

* Removed the part 2 and 2x instances

* Removed the irrelevant checks in wmma

* Refactored the instances to adapt to new device implementation

* Updated the reference and include files

* enabling tests

* Added missing profiler

* Added missing instance entry, which had been deleted by mistake.

* Reduce bias bnorm clamp instances to only a single generic one.

* Clean up cmakelists file

* clang-format

* Change bias bnorm clamp tests to use monotone initialization values to avoid tiny off-integer gemm results on RDNA3 from blowing up.

* Renaming some instance lists and add functions to be more standardized.

* Commented out non default instances.

---------

Co-authored-by: kiefer <kiefer.van.teutem@streamhpc.com>
This commit is contained in:
ApoorvaKalyani
2026-01-22 09:53:59 +01:00
committed by GitHub
parent 0b13697a88
commit 8daf6ea302
16 changed files with 768 additions and 108 deletions

View File

@@ -24,9 +24,10 @@ using Empty_Tuple = ck::Tuple<>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
using Clamp = ck::tensor_operation::element_wise::Clamp;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
using Clamp = ck::tensor_operation::element_wise::Clamp;
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
@@ -40,6 +41,25 @@ static constexpr auto ConvFwdOddC =
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
// Generic bf16 WMMA CShuffle-V3 grouped conv-fwd instance list.
// Intentionally holds a single conservative configuration (see the PR note
// "Reduce bias bnorm clamp instances to only a single generic one") using
// GemmMNKPadding so it remains applicable as a fallback. DsDataTypes and
// OutElementOp are parameters so the same list serves fused epilogues such
// as BiasNormalizeInInferClamp with a multi-tensor Ds tuple.
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
@@ -146,6 +166,25 @@ using device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4 = std::tuple<
// clang-format on
>;
// Generic f16 WMMA CShuffle-V3 grouped conv-fwd instance list.
// Mirrors the bf16 generic list above: one conservative GemmMNKPadding
// configuration kept as a universally usable fallback. DsDataTypes and
// OutElementOp parameterize the fused epilogue (e.g. BiasNormalizeInInferClamp).
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,

View File

@@ -16,6 +16,10 @@
#include "grouped_convolution_forward_bias_bnorm_clamp_xdl.inc"
#endif
#ifdef CK_USE_WMMA
#include "grouped_convolution_forward_bias_bnorm_clamp_wmma_cshufflev3.inc"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
@@ -279,6 +283,59 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif // CK_USE_XDL
#ifdef CK_USE_WMMA
// layout NHWGC/GKYXC/NHWGK
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
}
#endif
}
// layout NDHWGC/GKZYXC/NDHWGK
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
}
#endif
}
#endif // CK_USE_WMMA
return op_ptrs;
}
};

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Declarations of the WMMA CShuffle-V3 grouped conv forward instance lists
// with fused bias + batchnorm(inference) + clamp epilogue
// (CDE op = BiasNormalizeInInferClamp).
// The 5-element Ds tuple carries the per-channel epilogue inputs —
// presumably bias plus the four batchnorm inference tensors; confirm the
// exact ordering against the BiasNormalizeInInferClamp element-wise op.
#ifdef CK_ENABLE_BF16
// 2D, NHWGC/GKYXC/NHWGK, bf16
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BF16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
// 3D, NDHWGC/GKZYXC/NDHWGK, bf16
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
BF16,
BF16,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BF16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
// 2D, NHWGC/GKYXC/NHWGK, f16
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16, F16, F16, F16, F16>,
F16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
// 3D, NDHWGC/GKZYXC/NDHWGK, f16
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16, F16, F16, F16, F16>,
F16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,7 +1,7 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# XDL_AND_WMMA_KERNELS
set(GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP)
include(ShardInstantiation)
@@ -69,15 +69,6 @@ generate_sharded_instantiations(
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
# large tensor
# NHWGC, GKYXC, NHWGK
@@ -89,7 +80,6 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances
@@ -108,6 +98,15 @@ generate_sharded_instantiations(
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances
@@ -193,7 +192,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
@@ -325,4 +324,11 @@ generate_sharded_instantiations(
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance ${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP})
#WMMA_Cshuffle_v3
# Register the hand-written WMMA CShuffle-V3 bias+bnorm+clamp instances
# together with the sharded XDL instantiations accumulated in
# GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP above.
add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance
wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP}
)

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
/// @brief Appends the generic bf16 WMMA CShuffle-V3 grouped conv2d forward
///        instance (fused bias + batchnorm-inference + clamp epilogue,
///        NHWGC/GKYXC/NHWGK layouts) to @p instances.
/// @param instances factory-owned list of type-erased device ops; entries
///                  are only appended, never removed.
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BF16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances)
{
// Only the ConvFwdDefault specialization is currently registered.
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BiasNormalizeInInferClamp>{});
// Note: the 1x1 specializations below are commented out temporarily;
// they might be re-enabled later.
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
// 2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
// NHWGK,
// ConvFwd1x1P0,
// Tuple<BF16, BF16, BF16, BF16, BF16>,
// BiasNormalizeInInferClamp>{});
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
// 2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
// NHWGK,
// ConvFwd1x1S1P0,
// Tuple<BF16, BF16, BF16, BF16, BF16>,
// BiasNormalizeInInferClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
/// @brief Appends the generic f16 WMMA CShuffle-V3 grouped conv2d forward
///        instance (fused bias + batchnorm-inference + clamp epilogue,
///        NHWGC/GKYXC/NHWGK layouts) to @p instances.
/// @param instances factory-owned list of type-erased device ops; entries
///                  are only appended, never removed.
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16, F16, F16, F16, F16>,
F16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances)
{
// Only the ConvFwdDefault specialization is currently registered.
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<F16, F16, F16, F16, F16>,
BiasNormalizeInInferClamp>{});
// Note: the 1x1 specializations below are commented out temporarily;
// they might be re-enabled later.
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
// 2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
// NHWGK,
// ConvFwd1x1P0,
// Tuple<F16, F16, F16, F16, F16>,
// BiasNormalizeInInferClamp>{});
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
// 2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
// NHWGK,
// ConvFwd1x1S1P0,
// Tuple<F16, F16, F16, F16, F16>,
// BiasNormalizeInInferClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,8 +1,8 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
set(GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP)
# XDL_AND_WMMA_KERNELS
set(GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP)
include(ShardInstantiation)
@@ -11,7 +11,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -20,7 +20,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -29,7 +29,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -38,7 +38,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -47,7 +47,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.in
NUM_SHARDS 4
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -56,7 +56,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.in
NUM_SHARDS 4
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -65,7 +65,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -74,7 +74,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
# large tensor
@@ -85,16 +85,17 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
@@ -103,7 +104,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.in
NUM_SHARDS 2
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
@@ -112,7 +113,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
NUM_SHARDS 2
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
@@ -124,7 +125,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
)
@@ -133,7 +134,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
)
@@ -142,7 +143,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
)
@@ -151,7 +152,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
)
#mem
@@ -162,16 +163,15 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.in
NUM_SHARDS 20
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.in
NUM_SHARDS 20
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -180,7 +180,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
# NDHWGC, GKZYXC, NDHWGK
@@ -190,7 +190,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -199,7 +199,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.in
NUM_SHARDS 20
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -208,7 +208,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.in
NUM_SHARDS 20
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -217,7 +217,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -226,7 +226,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -238,7 +238,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in
NUM_SHARDS 11
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -247,7 +247,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in
NUM_SHARDS 1
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -256,7 +256,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.in
NUM_SHARDS 4
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -265,7 +265,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in
NUM_SHARDS 4
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -274,7 +274,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instance.in
NUM_SHARDS 1
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -283,7 +283,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instance.in
NUM_SHARDS 1
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -292,7 +292,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instance.in
NUM_SHARDS 5
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -301,8 +301,14 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instance.in
NUM_SHARDS 12
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance ${GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP})
#WMMA_Cshuffle_v3
# Register the hand-written WMMA CShuffle-V3 bias+bnorm+clamp instances
# together with the sharded XDL instantiations accumulated in
# GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP above.
# NOTE(review): the two wmma file names below carry the 2D layout tokens
# "nhwgc_gkyxc_nhwgk" even though they register conv3d
# (ndhwgc_gkzyxc_ndhwgk) instances — confirm the on-disk file names really
# match, since this looks like a copy-paste from the conv2d CMakeLists.
add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance
wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
${GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP}
)

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
/// @brief Appends the generic bf16 WMMA CShuffle-V3 grouped conv3d forward
///        instance (fused bias + batchnorm-inference + clamp epilogue,
///        NDHWGC/GKZYXC/NDHWGK layouts) to @p instances.
/// @param instances factory-owned list of type-erased device ops; entries
///                  are only appended, never removed.
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
BF16,
BF16,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BF16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances)
{
// Only the ConvFwdDefault specialization is currently registered.
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BiasNormalizeInInferClamp>{});
// Note: the 1x1 specializations below are commented out temporarily;
// they might be re-enabled later.
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
// 3,
// NDHWGC,
// GKZYXC,
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
// NDHWGK,
// ConvFwd1x1P0,
// Tuple<BF16, BF16, BF16, BF16, BF16>,
// BiasNormalizeInInferClamp>{});
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
// 3,
// NDHWGC,
// GKZYXC,
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
// NDHWGK,
// ConvFwd1x1S1P0,
// Tuple<BF16, BF16, BF16, BF16, BF16>,
// BiasNormalizeInInferClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
/// @brief Appends the generic f16 WMMA CShuffle-V3 grouped conv3d forward
///        instance (fused bias + batchnorm-inference + clamp epilogue,
///        NDHWGC/GKZYXC/NDHWGK layouts) to @p instances.
/// @param instances factory-owned list of type-erased device ops; entries
///                  are only appended, never removed.
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16, F16, F16, F16, F16>,
F16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances)
{
// Only the ConvFwdDefault specialization is currently registered.
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<F16, F16, F16, F16, F16>,
BiasNormalizeInInferClamp>{});
// Note: the 1x1 specializations below are commented out temporarily;
// they might be re-enabled later.
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
// 3,
// NDHWGC,
// GKZYXC,
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
// NDHWGK,
// ConvFwd1x1P0,
// Tuple<F16, F16, F16, F16, F16>,
// BiasNormalizeInInferClamp>{});
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
// 3,
// NDHWGC,
// GKZYXC,
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
// NDHWGK,
// ConvFwd1x1S1P0,
// Tuple<F16, F16, F16, F16, F16>,
// BiasNormalizeInInferClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck