Grouped conv_fwd_bias_bnorm_clamp instances and tests (#3525)

* Added bias_bnorm_clamp instances.

* fwd_bias_bnorm_clamp comp instances

* fwd_bias_bnorm_mem_inter and mem_intra instances

* fwd_bias_bnorm_merged_group_instances

* fwd_bias_bnorm_clamp_conv3d_bf16 and f16 instances

* Device level changes for fwd_bias_bnorm_clamp

* Added the test to the regression test list.

* Removed the part 2 and 2x instances

* Removed the irrelevant checks in wmma

* Refactored the instances to adapt to new device implementation

* Updated the reference and include files

* enabling tests

* Added missing profiler

* Added missing instance entry, which had been deleted by mistake.

* Reduce bias bnorm clamp instances to only a single generic one.

* Clean up cmakelists file

* clang-format

* Change bias bnorm clamp tests to use monotone initialization values to avoid tiny off-integer gemm results on RDNA3 from blowing up.

* Renaming some instance lists and add functions to be more standardized.

* Commented out non default instances.

---------

Co-authored-by: kiefer <kiefer.van.teutem@streamhpc.com>
This commit is contained in:
ApoorvaKalyani
2026-01-22 09:53:59 +01:00
committed by GitHub
parent 0b13697a88
commit 8daf6ea302
16 changed files with 768 additions and 108 deletions

View File

@@ -24,9 +24,10 @@ using Empty_Tuple = ck::Tuple<>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
using Clamp = ck::tensor_operation::element_wise::Clamp;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
using Clamp = ck::tensor_operation::element_wise::Clamp;
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
@@ -40,6 +41,25 @@ static constexpr auto ConvFwdOddC =
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
// Generic (catch-all) bf16 grouped conv-fwd instance list for the WMMA
// CShuffleV3 kernel. Contains a single 64x64x64-tile instance with
// GemmMNKPadding and scalar-per-vector 1 on the C store, so it is applicable
// to arbitrary problem sizes. DsDataTypes/OutElementOp default to an empty
// tuple / PassThrough; fused epilogues (e.g. BiasNormalizeInInferClamp with a
// 5-tensor Ds tuple) supply their own.
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
@@ -146,6 +166,25 @@ using device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4 = std::tuple<
// clang-format on
>;
// Generic (catch-all) f16 grouped conv-fwd instance list for the WMMA
// CShuffleV3 kernel; f16 counterpart of the bf16 generic list above in spirit
// (same single 64x64x64 tile, GemmMNKPadding, C-store scalar-per-vector 1,
// valid for arbitrary problem sizes).
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,

View File

@@ -16,6 +16,10 @@
#include "grouped_convolution_forward_bias_bnorm_clamp_xdl.inc"
#endif
#ifdef CK_USE_WMMA
#include "grouped_convolution_forward_bias_bnorm_clamp_wmma_cshufflev3.inc"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
@@ -279,6 +283,59 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif // CK_USE_XDL
#ifdef CK_USE_WMMA
// layout NHWGC/GKYXC/NHWGK
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
}
#endif
}
// layout NDHWGC/GKZYXC/NDHWGK
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
}
#endif
}
#endif // CK_USE_WMMA
return op_ptrs;
}
};

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Declarations of the WMMA CShuffleV3 bias+batchnorm(inference)+clamp fused
// grouped conv-fwd instance adders. The Ds tuple carries five per-channel
// tensors consumed by BiasNormalizeInInferClamp (presumably bias, bnorm
// scale/shift/mean/variance — confirm against the element-wise op definition).
#ifdef CK_ENABLE_BF16
// 2D bf16, NHWGC/GKYXC/NHWGK layouts.
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BF16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
// 3D bf16, NDHWGC/GKZYXC/NDHWGK layouts.
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
BF16,
BF16,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BF16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
// 2D f16, NHWGC/GKYXC/NHWGK layouts.
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16, F16, F16, F16, F16>,
F16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
// 3D f16, NDHWGC/GKZYXC/NDHWGK layouts.
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16, F16, F16, F16, F16>,
F16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,7 +1,7 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# XDL_AND_WMMA_KERNELS
set(GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP)
include(ShardInstantiation)
@@ -69,15 +69,6 @@ generate_sharded_instantiations(
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
# large tensor
# NHWGC, GKYXC, NHWGK
@@ -89,7 +80,6 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances
@@ -108,6 +98,15 @@ generate_sharded_instantiations(
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances
@@ -193,7 +192,7 @@ generate_sharded_instantiations(
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
@@ -325,4 +324,11 @@ generate_sharded_instantiations(
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance ${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP})
#WMMA_Cshuffle_v3
add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance
wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP}
)

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Appends the 2D bf16 WMMA CShuffleV3 bias+bnorm+clamp instances (currently
// only the single generic ConvFwdDefault instance) to `instances`.
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BF16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances)
{
// Generic instance: valid for any problem size (GemmMNKPadding fallback).
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BiasNormalizeInInferClamp>{});
// NOTE: the 1x1 conv specializations below are temporarily disabled; they may
// be re-enabled later.
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
// 2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
// NHWGK,
// ConvFwd1x1P0,
// Tuple<BF16, BF16, BF16, BF16, BF16>,
// BiasNormalizeInInferClamp>{});
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
// 2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
// NHWGK,
// ConvFwd1x1S1P0,
// Tuple<BF16, BF16, BF16, BF16, BF16>,
// BiasNormalizeInInferClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Appends the 2D f16 WMMA CShuffleV3 bias+bnorm+clamp instances (currently
// only the single generic ConvFwdDefault instance) to `instances`.
void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16, F16, F16, F16, F16>,
F16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances)
{
// Generic instance: valid for any problem size (GemmMNKPadding fallback).
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
2,
NHWGC,
GKYXC,
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<F16, F16, F16, F16, F16>,
BiasNormalizeInInferClamp>{});
// NOTE: the 1x1 conv specializations below are temporarily disabled; they may
// be re-enabled later.
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
// 2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
// NHWGK,
// ConvFwd1x1P0,
// Tuple<F16, F16, F16, F16, F16>,
// BiasNormalizeInInferClamp>{});
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
// 2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
// NHWGK,
// ConvFwd1x1S1P0,
// Tuple<F16, F16, F16, F16, F16>,
// BiasNormalizeInInferClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,8 +1,8 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
set(GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP)
# XDL_AND_WMMA_KERNELS
set(GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP)
include(ShardInstantiation)
@@ -11,7 +11,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -20,7 +20,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -29,7 +29,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -38,7 +38,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -47,7 +47,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.in
NUM_SHARDS 4
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -56,7 +56,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.in
NUM_SHARDS 4
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -65,7 +65,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
@@ -74,7 +74,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl
)
# large tensor
@@ -85,16 +85,17 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
@@ -103,7 +104,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.in
NUM_SHARDS 2
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
@@ -112,7 +113,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
NUM_SHARDS 2
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
)
@@ -124,7 +125,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
)
@@ -133,7 +134,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
)
@@ -142,7 +143,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
)
@@ -151,7 +152,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in
NUM_SHARDS 3
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
)
#mem
@@ -162,16 +163,15 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.in
NUM_SHARDS 20
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.in
NUM_SHARDS 20
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -180,7 +180,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
# NDHWGC, GKZYXC, NDHWGK
@@ -190,7 +190,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -199,7 +199,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.in
NUM_SHARDS 20
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -208,7 +208,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.in
NUM_SHARDS 20
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -217,7 +217,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -226,7 +226,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in
NUM_SHARDS 16
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
@@ -238,7 +238,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in
NUM_SHARDS 11
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -247,7 +247,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in
NUM_SHARDS 1
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -256,7 +256,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.in
NUM_SHARDS 4
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -265,7 +265,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in
NUM_SHARDS 4
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -274,7 +274,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instance.in
NUM_SHARDS 1
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -283,7 +283,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instance.in
NUM_SHARDS 1
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -292,7 +292,7 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instance.in
NUM_SHARDS 5
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
@@ -301,8 +301,14 @@ generate_sharded_instantiations(
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instance.in
NUM_SHARDS 12
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance ${GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP})
#WMMA_Cshuffle_v3
# NOTE(review): these conv3d wmma source file names carry the 2D layout tokens
# (nhwgc_gkyxc_nhwgk), while the new 3D instance files in this change define
# functions named with ndhwgc_gkzyxc_ndhwgk. Confirm the on-disk file names —
# if they follow the function naming, these entries will not resolve and the
# 3D wmma instances will be missing from the library.
add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance
wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
${GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP}
)

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Appends the 3D bf16 WMMA CShuffleV3 bias+bnorm+clamp instances (currently
// only the single generic ConvFwdDefault instance) to `instances`.
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
BF16,
BF16,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BF16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances)
{
// Generic instance: valid for any problem size (GemmMNKPadding fallback).
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<BF16, BF16, BF16, BF16, BF16>,
BiasNormalizeInInferClamp>{});
// NOTE: the 1x1 conv specializations below are temporarily disabled; they may
// be re-enabled later.
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
// 3,
// NDHWGC,
// GKZYXC,
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
// NDHWGK,
// ConvFwd1x1P0,
// Tuple<BF16, BF16, BF16, BF16, BF16>,
// BiasNormalizeInInferClamp>{});
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances<
// 3,
// NDHWGC,
// GKZYXC,
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
// NDHWGK,
// ConvFwd1x1S1P0,
// Tuple<BF16, BF16, BF16, BF16, BF16>,
// BiasNormalizeInInferClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Registers grouped 3D forward-convolution instances fused with bias +
// batchnorm (inference) + clamp for f16 tensors in NDHWGC/GKZYXC/NDHWGK
// layouts. The five-element D tuple carries the extra element-wise inputs
// consumed by BiasNormalizeInInferClamp (bias, mean, variance, scale, shift
// per the profiler implementation -- TODO confirm ordering against the
// element-wise op definition).
void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<
DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16, F16, F16, F16, F16>,
F16,
PassThrough,
PassThrough,
BiasNormalizeInInferClamp>>>& instances)
{
// Only the single generic instance for the default conv specialization is
// registered, keeping the instance count (and binary size) small.
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<F16, F16, F16, F16, F16>,
BiasNormalizeInInferClamp>{});
// Note: The 1x1 specializations below are commented out temporarily; they
// may be re-enabled later.
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
// 3,
// NDHWGC,
// GKZYXC,
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
// NDHWGK,
// ConvFwd1x1P0,
// Tuple<F16, F16, F16, F16, F16>,
// BiasNormalizeInInferClamp>{});
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
// 3,
// NDHWGC,
// GKZYXC,
// Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
// NDHWGK,
// ConvFwd1x1S1P0,
// Tuple<F16, F16, F16, F16, F16>,
// BiasNormalizeInInferClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -122,12 +122,12 @@ template <ck::index_t NDimSpatial,
typename BComputeType = AComputeType,
typename IndexType = ck::index_t,
bool ElementwiseGK = false>
bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
int instance_index = -1)
bool profile_grouped_conv_fwd_bias_bnorm_clamp_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
int instance_index = -1)
{
const float floor = 0.f;
const float ceil = 2048.f;
@@ -198,18 +198,29 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
std::cout << "scale: " << scale.mDesc << std::endl;
std::cout << "shift: " << shift.mDesc << std::endl;
// Note: For the integer initialization method (used for verification in the tests), the
// initialization ranges were chosen so that the overall operation is monotone: every
// multiplication and addition involves non-negative operands. Without this, the output
// element-wise op can make small relative errors arbitrarily large by shifting results toward
// zero. With small integer inputs the gemm output should be exact, so this would normally not
// matter; on RDNA3, however, integer inputs can produce slightly off-integer outputs (a
// separate issue to investigate). Since letting the output element-wise op blow up tiny errors
// is not reasonable, the operation is kept monotone for now. Moving away from monotone ranges
// would require a proper error-propagation analysis, which is much more complicated.
switch(init_method)
{
case 0: break;
case 1:
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
weight.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{0, 5});
weight.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{0, 5});
bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
mean.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{0, 5});
// Mean is drawn from a non-positive range because it is subtracted, keeping the operation
// monotone.
mean.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 0});
variance.GenerateTensorValue(GeneratorTensor_2<OutDataType>{0, 5});
scale.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
shift.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
scale.GenerateTensorValue(GeneratorTensor_2<OutDataType>{0, 5});
shift.GenerateTensorValue(GeneratorTensor_2<OutDataType>{0, 5});
break;
default:
input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});

View File

@@ -100,6 +100,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_bnorm_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp)
@@ -240,6 +241,8 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_bnorm_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance)

View File

@@ -0,0 +1,202 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/ignore.hpp"
#include "profiler_operation_registry.hpp"
#include <iostream>
enum struct ConvLayout
{
GNHWC_GKYXC_GNHWK, // 0
NHWGC_GKYXC_NHWGK, // 1
NGCHW_GKYXC_NGKHW, // 2
NGCHW_GKCYX_NGKHW, // 3
};
enum struct ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
F8_F8_F8, // 4
BF8_BF8_F8, // 5
F8_BF8_F8, // 6
BF8_F8_F8, // 7
F32_F32_F32_TF32, // 8
};
enum struct IndexType
{
INDEX_T, // 0
LONG_INDEX_T, // 1
};
#define OP_NAME "grouped_conv_fwd_bias_bnorm_clamp"
#define OP_DESC "Grouped Convolution Forward+Bias+Bnorm+Clamp"
// Prints the command-line usage text for the
// "grouped_conv_fwd_bias_bnorm_clamp" profiler operation, including the
// per-dimension convolution parameter description supplied by
// get_conv_param_parser_helper_msg().
static void print_helper_msg()
{
std::cout
// clang-format off
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n"
<< " 1: Input fp16, Weight fp16, Output fp16\n"
<< " 2: Input bf16, Weight bf16, Output bf16\n"
<< " 3: Input int8, Weight int8, Output int8\n"
<< " 4: Input fp8, Weight fp8, Output fp8\n"
<< " 5: Input bf8, Weight bf8, Output fp8\n"
<< " 6: Input fp8, Weight bf8, Output fp8\n"
<< " 7: Input bf8, Weight fp8, Output fp8\n"
<< " 8: Input fp32, Weight fp32, Output fp32, Compute tf32)\n"
<< "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n"
<< " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
"G, K, Ho, Wo]\n"
<< " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, "
"G, K, Ho, Wo])\n"
<< "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n"
<< "arg5: verification (0: no, 1: yes)\n"
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg7: print tensor value (0: no; 1: yes)\n"
<< "arg8: time kernel (0: no, 1: yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
// clang-format on
}
// Entry point for the "grouped_conv_fwd_bias_bnorm_clamp" profiler operation.
// Parses the control arguments (argv[2..8]), the spatial dimensionality
// (argv[9]) and the convolution problem description, then dispatches to the
// templated profiling implementation for the selected layout / data-type
// combination. Returns 0 when profiling passes, 1 on bad arguments or an
// unsupported configuration.
int grouped_conv_fwd_bias_bnorm_clamp(int argc, char* argv[])
{
// 8 for control, 1 for num_dim_spatial
if(argc < 10)
{
print_helper_msg();
return 1;
}
// NOTE: std::stoi throws on non-numeric input; malformed arguments will
// terminate the profiler rather than print the help message.
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
const auto index_type = static_cast<IndexType>(std::stoi(argv[4]));
const bool do_verification = std::stoi(argv[5]);
const int init_method = std::stoi(argv[6]);
const bool do_log = std::stoi(argv[7]);
const bool time_kernel = std::stoi(argv[8]);
const int num_dim_spatial = std::stoi(argv[9]);
// 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
if(argc != 9 + 1 + 4 + 6 * num_dim_spatial)
{
print_helper_msg();
return 1;
}
// Conv params start at argv[10], right after the control arguments.
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv);
// Only 32-bit indexing is wired up for this operation at the moment.
if(index_type != IndexType::INDEX_T)
{
std::cout << "this indexing data type is not implemented" << std::endl;
return 1;
}
using F32 = float;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using TF32 = ck::tf32_t;
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using NHWGC = ck::tensor_layout::convolution::NHWGC;
using NHWGK = ck::tensor_layout::convolution::NHWGK;
constexpr auto I2 = ck::Number<2>{};
constexpr auto I3 = ck::Number<3>{};
// Instantiates the templated profiling implementation for a concrete
// rank / layout / data-type combination and converts its pass/fail result
// into a process exit code (0 = pass, 1 = fail).
auto profile = [&](auto num_dim_spatial_tmp,
auto in_layout,
auto wei_layout,
auto out_layout,
auto in_type,
auto wei_type,
auto out_type,
auto a_compute_type,
auto b_compute_type) {
constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value;
using InLayout = decltype(in_layout);
using WeiLayout = decltype(wei_layout);
using OutLayout = decltype(out_layout);
using InDataType = decltype(in_type);
using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type);
using AComputeType = decltype(a_compute_type);
using BComputeType = decltype(b_compute_type);
bool pass = ck::profiler::profile_grouped_conv_fwd_bias_bnorm_clamp_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
AComputeType,
BComputeType>(
do_verification, init_method, do_log, time_kernel, params);
return pass ? 0 : 1;
};
// The 2D NHWGC and 3D NDHWGC cases share the same ConvLayout enum value (1);
// num_dim_spatial selects between them. Only channels-last layouts and the
// fp32/fp16/bf16/tf32 data types are implemented; everything else falls
// through to the error message below.
if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
// Registers the operation with the profiler's operation registry so it can be
// invoked by name (OP_NAME) from the profiler driver.
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_bias_bnorm_clamp);

View File

@@ -33,6 +33,7 @@ set(REGRESSION_TESTS
test_convnd_fwd
test_convnd_bwd_data
test_grouped_convnd_fwd
test_grouped_convnd_fwd_bias_bnorm_clamp
test_grouped_convnd_fwd_scaleadd_ab
test_grouped_convnd_bwd_weight
test_softmax_rank3

View File

@@ -1,15 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9|gfx12")
#Fail on gfx11 CI but fail to reproduce it in local, disable it temporary
add_gtest_executable(test_grouped_convnd_fwd_bias_bnorm_clamp test_grouped_convnd_fwd_bias_bnorm_clamp.cpp)
target_link_libraries(test_grouped_convnd_fwd_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
add_gtest_executable(test_grouped_convnd_fwd_gk_bias_bnorm_clamp test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp)
target_link_libraries(test_grouped_convnd_fwd_gk_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
endif()
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_grouped_convnd_fwd_bias_clamp test_grouped_convnd_fwd_bias_clamp.cpp)
target_link_libraries(test_grouped_convnd_fwd_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance)
@@ -26,4 +17,10 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_grouped_convnd_fwd_scale test_grouped_convnd_fwd_scale.cpp)
target_link_libraries(test_grouped_convnd_fwd_scale PRIVATE utility device_grouped_conv3d_fwd_scale_instance)
add_gtest_executable(test_grouped_convnd_fwd_bias_bnorm_clamp test_grouped_convnd_fwd_bias_bnorm_clamp.cpp)
target_link_libraries(test_grouped_convnd_fwd_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
add_gtest_executable(test_grouped_convnd_fwd_gk_bias_bnorm_clamp test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp)
target_link_libraries(test_grouped_convnd_fwd_gk_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
endif()

View File

@@ -39,23 +39,24 @@ class TestGroupedConvndFwd : public ::testing::Test
continue;
}
auto& param = conv_params[i];
pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
DataType,
DataType,
DataType,
DataType,
DataType,
IndexType,
false /*BiasGK*/>(
true, // do_verification
1, // init_method: integer value
false, // do_log
false, // time_kernel
param,
instance_index);
pass = pass &&
ck::profiler::profile_grouped_conv_fwd_bias_bnorm_clamp_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
DataType,
DataType,
DataType,
DataType,
DataType,
IndexType,
false /*BiasGK*/>(
true, // do_verification
1, // init_method: integer value
false, // do_log
false, // time_kernel
param,
instance_index);
}
EXPECT_TRUE(pass);
}

View File

@@ -38,24 +38,23 @@ class TestGroupedConvndFwd : public ::testing::Test
continue;
}
auto& param = conv_params[i];
pass = pass &&
ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
DataType,
DataType,
DataType,
DataType,
DataType,
IndexType,
true /*ElementwiseGK*/>(
true, // do_verification
1, // init_method: integer value
false, // do_log
false, // time_kernel
param,
instance_index);
pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_bnorm_clamp_impl<
NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
DataType,
DataType,
DataType,
DataType,
DataType,
IndexType,
true /*ElementwiseGK*/>(true, // do_verification
1, // init_method: integer value
false, // do_log
false, // time_kernel
param,
instance_index);
}
EXPECT_TRUE(pass);
}