mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit '5328b232b25cdf0989ba9ec5dbbda99e4933587c' into develop
This commit is contained in:
@@ -562,6 +562,58 @@ struct NormalizeInInfer
|
||||
double epsilon_;
|
||||
};
|
||||
|
||||
// used by Conv+Bias+BatchNorm+Clamp inference
|
||||
struct BiasNormalizeInInferClamp
|
||||
{
|
||||
BiasNormalizeInInferClamp(float floor = 0.f,
|
||||
float ceil = NumericLimits<float>::Max(),
|
||||
float epsilon = 1e-4)
|
||||
: clamp_(floor, ceil), epsilon_(epsilon)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y,
|
||||
const T& x,
|
||||
const T& bias,
|
||||
const T& mean,
|
||||
const T& variance,
|
||||
const T& gamma,
|
||||
const T& beta) const
|
||||
{
|
||||
using ck::type_convert;
|
||||
using ck::math::sqrt;
|
||||
|
||||
float tmp_x = type_convert<float>(x) + type_convert<float>(bias);
|
||||
|
||||
float tmp_y =
|
||||
((tmp_x - type_convert<float>(mean)) / sqrt(type_convert<float>(variance) + epsilon_)) *
|
||||
type_convert<float>(gamma) +
|
||||
type_convert<float>(beta);
|
||||
clamp_(tmp_y, tmp_y);
|
||||
y = type_convert<T>(tmp_y);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()(float& y,
|
||||
const float& x,
|
||||
const float& bias,
|
||||
const float& mean,
|
||||
const float& variance,
|
||||
const float& gamma,
|
||||
const float& beta) const
|
||||
{
|
||||
using ck::type_convert;
|
||||
using ck::math::sqrt;
|
||||
|
||||
float tmp_y = (((x + bias) - mean) / sqrt(variance + epsilon_)) * gamma + beta;
|
||||
clamp_(y, tmp_y);
|
||||
};
|
||||
|
||||
Clamp clamp_;
|
||||
float epsilon_;
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
struct UnaryTypeConvert;
|
||||
|
||||
|
||||
@@ -113,29 +113,30 @@ using GK_Tuple = ck::Tuple<G_K>;
|
||||
using GK_GK_Tuple = ck::Tuple<G_K, G_K>;
|
||||
|
||||
// pointwise functor
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Relu = ck::tensor_operation::element_wise::Relu;
|
||||
using TanH = ck::tensor_operation::element_wise::TanH;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
|
||||
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using AddSilu = ck::tensor_operation::element_wise::AddSilu;
|
||||
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
|
||||
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
|
||||
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
|
||||
using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;
|
||||
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
|
||||
using Gelu = ck::tensor_operation::element_wise::Gelu;
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Relu = ck::tensor_operation::element_wise::Relu;
|
||||
using TanH = ck::tensor_operation::element_wise::TanH;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
|
||||
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
|
||||
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using AddSilu = ck::tensor_operation::element_wise::AddSilu;
|
||||
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
|
||||
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
|
||||
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
|
||||
using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;
|
||||
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
|
||||
using Gelu = ck::tensor_operation::element_wise::Gelu;
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
|
||||
template <typename Activation>
|
||||
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
|
||||
|
||||
@@ -32,9 +32,10 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -32,9 +32,10 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -24,9 +24,10 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -32,9 +32,10 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -24,9 +24,10 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#include "grouped_convolution_forward_bias_bnorm_clamp_xdl.inc"
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DLayouts,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename DDataTypes,
|
||||
typename AComputeType,
|
||||
typename BComputeType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DLayouts,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DDataTypes,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::BiasNormalizeInInferClamp,
|
||||
AComputeType,
|
||||
BComputeType>>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleABD<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DLayouts,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DDataTypes,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::BiasNormalizeInInferClamp,
|
||||
AComputeType,
|
||||
BComputeType>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
// layout NHWGC/GKYXC/NHWGK
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<AComputeType, ck::bhalf_t> &&
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
// layout NDHWGC/GKZYXC/NDHWGK
|
||||
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
|
||||
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<AComputeType, ck::bhalf_t> &&
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,776 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,240 @@
|
||||
# ONLY XDL_KERNELS
|
||||
set(GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP)
|
||||
include(ShardInstantiation)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.in
|
||||
NUM_SHARDS 4
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instance.in
|
||||
NUM_SHARDS 4
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
# large tensor
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances
|
||||
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances
|
||||
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances
|
||||
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.in
|
||||
NUM_SHARDS 2
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
# merged groups
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances
|
||||
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances
|
||||
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances
|
||||
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
#mem
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.in
|
||||
NUM_SHARDS 20
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.in
|
||||
NUM_SHARDS 20
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.in
|
||||
NUM_SHARDS 20
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.in
|
||||
NUM_SHARDS 20
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
#comp
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.in
|
||||
NUM_SHARDS 11
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.in
|
||||
NUM_SHARDS 1
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.in
|
||||
NUM_SHARDS 4
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.in
|
||||
NUM_SHARDS 1
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.in
|
||||
NUM_SHARDS 1
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.in
|
||||
NUM_SHARDS 5
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.in
|
||||
NUM_SHARDS 12
|
||||
SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance ${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP})
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances& instances)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances& instances)
|
||||
{
|
||||
if(ck::get_device_name() != "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances& instances)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances& instances)
|
||||
{
|
||||
if(ck::get_device_name() != "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,62 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,62 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,62 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_16x16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,43 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_large_tensor_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,43 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,43 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Interwave,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Interwave,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Intrawave,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Intrawave,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Interwave,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Interwave,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Intrawave,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Intrawave,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Interwave,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Interwave,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Intrawave,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Intrawave,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,79 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances& instances)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd3x3,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd3x3,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,79 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances& instances)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f16_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f16_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd3x3,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd3x3,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances_shard(device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC, Tuple<NHWGK, NHWGK, NHWGK, NHWGK, NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwd3x3,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,240 @@
|
||||
# ONLY XDL_KERNELS
|
||||
set(GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP)
|
||||
include(ShardInstantiation)
|
||||
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.in
|
||||
NUM_SHARDS 4
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.in
|
||||
NUM_SHARDS 4
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances
|
||||
TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl
|
||||
)
|
||||
# large tensor
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances
|
||||
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances
|
||||
TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.in
|
||||
NUM_SHARDS 2
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor
|
||||
)
|
||||
# merged groups
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances
|
||||
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances
|
||||
TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.in
|
||||
NUM_SHARDS 3
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups
|
||||
)
|
||||
#mem
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.in
|
||||
NUM_SHARDS 20
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.in
|
||||
NUM_SHARDS 20
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.in
|
||||
NUM_SHARDS 20
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.in
|
||||
NUM_SHARDS 20
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances
|
||||
TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.in
|
||||
NUM_SHARDS 16
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
|
||||
)
|
||||
#comp
|
||||
# NDHWGC, GKZYXC, NDHWGK
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in
|
||||
NUM_SHARDS 11
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in
|
||||
NUM_SHARDS 1
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.in
|
||||
NUM_SHARDS 4
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instance.in
|
||||
NUM_SHARDS 1
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instance.in
|
||||
NUM_SHARDS 1
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instance.in
|
||||
NUM_SHARDS 5
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
|
||||
generate_sharded_instantiations(
|
||||
INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances
|
||||
TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instance.in
|
||||
NUM_SHARDS 12
|
||||
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP
|
||||
OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance ${GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP})
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances& instances)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances& instances)
|
||||
{
|
||||
if(ck::get_device_name() != "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances& instances)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances& instances)
|
||||
{
|
||||
if(ck::get_device_name() != "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,62 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_comp_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,62 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_16x16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_16x16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_16x16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,62 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_16x16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_16x16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_16x16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,43 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_large_tensor_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,43 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,43 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Interwave,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Interwave,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Intrawave,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Intrawave,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Interwave,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Interwave,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Intrawave,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Intrawave,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Interwave,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Interwave,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Intrawave,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Intrawave,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,79 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16, BF16, BF16, BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances& instances)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd3x3,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd3x3,Tuple<BF16, BF16, BF16, BF16, BF16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,79 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16, F16, F16, F16, F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances& instances)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f16_instances_2x<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f16_instances_2x<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd3x3,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd3x3,Tuple<F16, F16, F16, F16, F16>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
#include "ck/utility/filter_tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances = std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Tuple<F32, F32, F32, F32, F32>,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BiasNormalizeInInferClamp>>>;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
template <int Shards, int ShardIndex>
|
||||
void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances_shard(device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances, ck::util::filter_tuple_by_modulo_t<
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC, Tuple<NDHWGK, NDHWGK, NDHWGK, NDHWGK, NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwd3x3,Tuple<F32, F32, F32, F32, F32>, BiasNormalizeInInferClamp>,
|
||||
Shards,
|
||||
ShardIndex>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,427 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <typeinfo>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
// NOTE: Usage of NHWGK layout for GK bias is a workaround. This test is to
|
||||
// just keep such implementation valid.
|
||||
// TODO: Add possiblity to pass GK layout and GK lengths for bias and reuse
|
||||
// the same instances.
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
auto get_elementwise_desc(ck::index_t G, ck::index_t K)
|
||||
{
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
return HostTensorDescriptor({G, 1, K, 1}, {K, 0, 1, 0});
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
return HostTensorDescriptor({G, 1, K, 1, 1}, {K, 0, 1, 0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({G, 1, K, 1, 1, 1}, {K, 0, 1, 0, 0, 0});
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial, typename OutDataType>
|
||||
void ref_bnorm_clamp_infer(Tensor<OutDataType>& out,
|
||||
Tensor<OutDataType>& in,
|
||||
Tensor<OutDataType>& mean,
|
||||
Tensor<OutDataType>& variance,
|
||||
Tensor<OutDataType>& scale,
|
||||
Tensor<OutDataType>& shift,
|
||||
const float floor,
|
||||
const float ceil,
|
||||
const float epsilon)
|
||||
{
|
||||
|
||||
auto func = [&](auto... idxs) {
|
||||
const float x = type_convert<float>(in(idxs...));
|
||||
|
||||
const float invVariance =
|
||||
type_convert<float>(1.0f) / std::sqrt(epsilon + type_convert<float>(variance(idxs...)));
|
||||
|
||||
const float norm_x = (x - type_convert<float>(mean(idxs...))) * invVariance;
|
||||
|
||||
float y =
|
||||
type_convert<float>(scale(idxs...)) * norm_x + type_convert<float>(shift(idxs...));
|
||||
|
||||
Clamp{floor, ceil}(y, y);
|
||||
|
||||
out(idxs...) = type_convert<OutDataType>(y);
|
||||
};
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
make_ParallelTensorFunctor(func,
|
||||
out.GetLengths()[0],
|
||||
out.GetLengths()[1],
|
||||
out.GetLengths()[2],
|
||||
out.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
make_ParallelTensorFunctor(func,
|
||||
out.GetLengths()[0],
|
||||
out.GetLengths()[1],
|
||||
out.GetLengths()[2],
|
||||
out.GetLengths()[3],
|
||||
out.GetLengths()[4])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
make_ParallelTensorFunctor(func,
|
||||
out.GetLengths()[0],
|
||||
out.GetLengths()[1],
|
||||
out.GetLengths()[2],
|
||||
out.GetLengths()[3],
|
||||
out.GetLengths()[4],
|
||||
out.GetLengths()[5])(std::thread::hardware_concurrency());
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AComputeType = InDataType,
|
||||
typename BComputeType = AComputeType,
|
||||
typename IndexType = ck::index_t,
|
||||
bool ElementwiseGK = false>
|
||||
bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param)
|
||||
{
|
||||
const float floor = 0.f;
|
||||
const float ceil = 2048.f;
|
||||
const float epsilon = 1e-4;
|
||||
|
||||
const auto in_element_op = InElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
const auto out_element_op = OutElementOp{floor, ceil, epsilon};
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
|
||||
|
||||
const index_t G = conv_param.G_;
|
||||
const index_t K = conv_param.K_;
|
||||
|
||||
std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_lengths{};
|
||||
std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_strides{};
|
||||
std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_lengths{};
|
||||
std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_strides{};
|
||||
std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_lengths{};
|
||||
std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_strides{};
|
||||
std::array<IndexType, NDimSpatial + 3> d_g_n_k_wos_strides{};
|
||||
std::array<IndexType, NDimSpatial> conv_filter_strides{};
|
||||
std::array<IndexType, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<IndexType, NDimSpatial> input_left_pads{};
|
||||
std::array<IndexType, NDimSpatial> input_right_pads{};
|
||||
|
||||
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
|
||||
|
||||
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
|
||||
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
|
||||
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
|
||||
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
|
||||
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
|
||||
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
|
||||
copy(out_g_n_k_wos_desc.GetStrides(), d_g_n_k_wos_strides);
|
||||
copy(conv_param.conv_filter_strides_, conv_filter_strides);
|
||||
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
|
||||
copy(conv_param.input_left_pads_, input_left_pads);
|
||||
copy(conv_param.input_right_pads_, input_right_pads);
|
||||
|
||||
Tensor<InDataType> input(in_g_n_c_wis_desc);
|
||||
Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
|
||||
Tensor<OutDataType> host_output(out_g_n_k_wos_desc);
|
||||
Tensor<OutDataType> device_output(out_g_n_k_wos_desc);
|
||||
const auto elementwise_desc =
|
||||
ElementwiseGK ? get_elementwise_desc<NDimSpatial>(G, K) : out_g_n_k_wos_desc;
|
||||
|
||||
Tensor<OutDataType> bias(elementwise_desc);
|
||||
Tensor<OutDataType> mean(elementwise_desc);
|
||||
Tensor<OutDataType> variance(elementwise_desc);
|
||||
Tensor<OutDataType> scale(elementwise_desc);
|
||||
Tensor<OutDataType> shift(elementwise_desc);
|
||||
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
std::cout << "output: " << host_output.mDesc << std::endl;
|
||||
|
||||
std::cout << "bias: " << bias.mDesc << std::endl;
|
||||
std::cout << "mean: " << mean.mDesc << std::endl;
|
||||
std::cout << "variance: " << variance.mDesc << std::endl;
|
||||
std::cout << "scale: " << scale.mDesc << std::endl;
|
||||
std::cout << "shift: " << shift.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
|
||||
weight.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
||||
|
||||
bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
mean.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
variance.GenerateTensorValue(GeneratorTensor_2<OutDataType>{0, 5});
|
||||
scale.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
shift.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
|
||||
weight.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
|
||||
|
||||
bias.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
|
||||
mean.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
|
||||
variance.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0, 0.5});
|
||||
scale.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
|
||||
shift.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize());
|
||||
|
||||
const std::size_t elementwise_dev_buf_size =
|
||||
ElementwiseGK ? sizeof(OutDataType) * G * K
|
||||
: sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize();
|
||||
DeviceMem bias_device_buf(elementwise_dev_buf_size);
|
||||
DeviceMem mean_device_buf(elementwise_dev_buf_size);
|
||||
DeviceMem variance_device_buf(elementwise_dev_buf_size);
|
||||
DeviceMem scale_device_buf(elementwise_dev_buf_size);
|
||||
DeviceMem shift_device_buf(elementwise_dev_buf_size);
|
||||
|
||||
in_device_buf.ToDevice(input.mData.data());
|
||||
wei_device_buf.ToDevice(weight.mData.data());
|
||||
|
||||
bias_device_buf.ToDevice(bias.mData.data());
|
||||
mean_device_buf.ToDevice(mean.mData.data());
|
||||
variance_device_buf.ToDevice(variance.mData.data());
|
||||
scale_device_buf.ToDevice(scale.mData.data());
|
||||
shift_device_buf.ToDevice(shift.mData.data());
|
||||
|
||||
if constexpr(ElementwiseGK)
|
||||
{
|
||||
constexpr ck::index_t spatial_offset = 3;
|
||||
d_g_n_k_wos_strides[1] = 0;
|
||||
for(int i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
d_g_n_k_wos_strides[i + spatial_offset] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// run reference op
|
||||
if(do_verification)
|
||||
{
|
||||
// Run Conv and Bnorm seperatly
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
Add,
|
||||
0,
|
||||
0,
|
||||
1>{};
|
||||
|
||||
std::array<Tensor<OutDataType>, 1> d_tensors = {bias};
|
||||
auto ref_conv_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_conv_argument = ref_conv.MakeArgument(input,
|
||||
weight,
|
||||
host_output,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
Add{},
|
||||
{},
|
||||
{},
|
||||
d_tensors);
|
||||
|
||||
// init host output to zero
|
||||
host_output.SetZero();
|
||||
ref_conv_invoker.Run(ref_conv_argument);
|
||||
ref_bnorm_clamp_infer<NDimSpatial>(
|
||||
host_output, host_output, mean, variance, scale, shift, floor, ceil, epsilon);
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_avg_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device op instances
|
||||
bool pass = true;
|
||||
|
||||
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
|
||||
// workspace_sz will be equal to 0 for other layout than NGCHW
|
||||
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
DeviceMem workspace_dev(workspace_sz);
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
// re-init output to zero before profiling next kernel
|
||||
out_device_buf.SetZero();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
float avg_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = conv_param.GetFlops();
|
||||
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_avg_time = avg_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
out_device_buf.FromDevice(device_output.mData.data());
|
||||
|
||||
pass = pass & ck::utils::check_err(device_output, host_output);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "input : ", input.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "weight: ", weight.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "host_output : ", host_output.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "device_output: ", device_output.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<OutLayout, OutLayout, OutLayout, OutLayout, OutLayout>,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
ck::Tuple<OutDataType, OutDataType, OutDataType, OutDataType, OutDataType>,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
AComputeType,
|
||||
BComputeType>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(),
|
||||
wei_device_buf.GetDeviceBuffer(),
|
||||
{bias_device_buf.GetDeviceBuffer(),
|
||||
mean_device_buf.GetDeviceBuffer(),
|
||||
variance_device_buf.GetDeviceBuffer(),
|
||||
scale_device_buf.GetDeviceBuffer(),
|
||||
shift_device_buf.GetDeviceBuffer()},
|
||||
out_device_buf.GetDeviceBuffer(),
|
||||
a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
{e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_lengths},
|
||||
{d_g_n_k_wos_strides,
|
||||
d_g_n_k_wos_strides,
|
||||
d_g_n_k_wos_strides,
|
||||
d_g_n_k_wos_strides,
|
||||
d_g_n_k_wos_strides},
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
run_impl(op_ptr, argument_ptr);
|
||||
}
|
||||
|
||||
std::cout << "Best configuration parameters:"
|
||||
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time
|
||||
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
@@ -1,4 +1,10 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_grouped_convnd_fwd_bias_bnorm_clamp test_grouped_convnd_fwd_bias_bnorm_clamp.cpp)
|
||||
target_link_libraries(test_grouped_convnd_fwd_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
|
||||
|
||||
add_gtest_executable(test_grouped_convnd_fwd_gk_bias_bnorm_clamp test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp)
|
||||
target_link_libraries(test_grouped_convnd_fwd_gk_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
|
||||
|
||||
add_gtest_executable(test_grouped_convnd_fwd_bias_clamp test_grouped_convnd_fwd_bias_clamp.cpp)
|
||||
target_link_libraries(test_grouped_convnd_fwd_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance)
|
||||
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using DataType = std::tuple_element_t<0, Tuple>;
|
||||
using InLayout = std::tuple_element_t<1, Tuple>;
|
||||
using WeiLayout = std::tuple_element_t<2, Tuple>;
|
||||
using OutLayout = std::tuple_element_t<3, Tuple>;
|
||||
using IndexType = ck::index_t;
|
||||
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
void Run()
|
||||
{
|
||||
EXPECT_FALSE(conv_params.empty());
|
||||
bool pass = true;
|
||||
for(auto& param : conv_params)
|
||||
{
|
||||
pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
IndexType,
|
||||
false /*BiasGK*/>(
|
||||
true, // do_verification
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param);
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using KernelTypes2d = ::testing::Types<std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>,
|
||||
std::tuple<float, NHWGC, GKYXC, NHWGK>,
|
||||
std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>>;
|
||||
|
||||
using KernelTypes3d = ::testing::Types<std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK>,
|
||||
std::tuple<float, NDHWGC, GKZYXC, NDHWGK>,
|
||||
std::tuple<ck::half_t, NDHWGC, GKZYXC, NDHWGK>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd2d : public TestGroupedConvndFwd<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd3d : public TestGroupedConvndFwd<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndFwd2d, Test2D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->template Run<2>();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->template Run<3>();
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using DataType = std::tuple_element_t<0, Tuple>;
|
||||
using InLayout = std::tuple_element_t<1, Tuple>;
|
||||
using WeiLayout = std::tuple_element_t<2, Tuple>;
|
||||
using OutLayout = std::tuple_element_t<3, Tuple>;
|
||||
using IndexType = ck::index_t;
|
||||
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
void Run()
|
||||
{
|
||||
EXPECT_FALSE(conv_params.empty());
|
||||
bool pass = true;
|
||||
for(auto& param : conv_params)
|
||||
{
|
||||
pass = pass &&
|
||||
ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
IndexType,
|
||||
true /*ElementwiseGK*/>(
|
||||
true, // do_verification
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param);
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using KernelTypes2d = ::testing::Types<std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>,
|
||||
std::tuple<float, NHWGC, GKYXC, NHWGK>,
|
||||
std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>>;
|
||||
|
||||
using KernelTypes3d = ::testing::Types<std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK>,
|
||||
std::tuple<float, NDHWGC, GKZYXC, NDHWGK>,
|
||||
std::tuple<ck::half_t, NDHWGC, GKZYXC, NDHWGK>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd2d : public TestGroupedConvndFwd<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd3d : public TestGroupedConvndFwd<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndFwd2d, Test2D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->template Run<2>();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->template Run<3>();
|
||||
}
|
||||
Reference in New Issue
Block a user