mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Merge commit '5328b232b25cdf0989ba9ec5dbbda99e4933587c' into develop
This commit is contained in:
@@ -113,29 +113,30 @@ using GK_Tuple = ck::Tuple<G_K>;
|
||||
using GK_GK_Tuple = ck::Tuple<G_K, G_K>;
|
||||
|
||||
// pointwise functor
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Relu = ck::tensor_operation::element_wise::Relu;
|
||||
using TanH = ck::tensor_operation::element_wise::TanH;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
|
||||
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using AddSilu = ck::tensor_operation::element_wise::AddSilu;
|
||||
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
|
||||
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
|
||||
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
|
||||
using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;
|
||||
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
|
||||
using Gelu = ck::tensor_operation::element_wise::Gelu;
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Relu = ck::tensor_operation::element_wise::Relu;
|
||||
using TanH = ck::tensor_operation::element_wise::TanH;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
|
||||
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
|
||||
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using AddSilu = ck::tensor_operation::element_wise::AddSilu;
|
||||
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
|
||||
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
|
||||
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
|
||||
using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;
|
||||
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
|
||||
using Gelu = ck::tensor_operation::element_wise::Gelu;
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
|
||||
template <typename Activation>
|
||||
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
|
||||
|
||||
@@ -32,9 +32,10 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -32,9 +32,10 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -24,9 +24,10 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -32,9 +32,10 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -24,9 +24,10 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#include "grouped_convolution_forward_bias_bnorm_clamp_xdl.inc"
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DLayouts,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename DDataTypes,
|
||||
typename AComputeType,
|
||||
typename BComputeType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DLayouts,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DDataTypes,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::BiasNormalizeInInferClamp,
|
||||
AComputeType,
|
||||
BComputeType>>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleABD<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DLayouts,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DDataTypes,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::BiasNormalizeInInferClamp,
|
||||
AComputeType,
|
||||
BComputeType>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
// layout NHWGC/GKYXC/NHWGK
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<AComputeType, ck::bhalf_t> &&
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
// layout NDHWGC/GKZYXC/NDHWGK
|
||||
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
|
||||
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<AComputeType, ck::bhalf_t> &&
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,776 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user