Add 2D bias clamp instances and tests

This commit is contained in:
kiefer
2025-09-01 12:22:26 +00:00
parent bc2c2fd54e
commit 4822517535
18 changed files with 1000 additions and 233 deletions

View File

@@ -228,13 +228,97 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
op_ptrs);
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
// op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_16x16_instances(
op_ptrs);
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
// op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances(
op_ptrs);
}
#endif
}
// layout NDHWGC/GKZYXC/NDHWGK
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
// op_ptrs);
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
// op_ptrs);
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
// op_ptrs);
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
// op_ptrs);
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
// op_ptrs);
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
// op_ptrs);
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
// op_ptrs);
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances(
// op_ptrs);
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
// op_ptrs);
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
// op_ptrs);
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(
// op_ptrs);
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
// op_ptrs);
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances(
// op_ptrs);
}
#endif
}

View File

@@ -10,34 +10,33 @@ namespace instance {
#ifdef CK_ENABLE_BF16
// void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// BF16,
// BF16,
// Tuple<BF16>,
// BF16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// BF16,
// BF16,
// Tuple<BF16>,
// BF16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
@@ -54,95 +53,61 @@ namespace instance {
// PassThrough,
// AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// BF16,
// BF16,
// Tuple<BF16>,
// BF16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// BF16,
// BF16,
// Tuple<BF16>,
// BF16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// BF16,
// BF16,
// Tuple<BF16>,
// BF16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// BF16,
// BF16,
// Tuple<BF16>,
// BF16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// BF16,
// BF16,
// Tuple<BF16>,
// BF16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// BF16,
// BF16,
// Tuple<BF16>,
// BF16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
// void
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
@@ -253,34 +218,33 @@ namespace instance {
#ifdef CK_ENABLE_FP16
// void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// F16,
// F16,
// Tuple<F16>,
// F16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_16x16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// F16,
// F16,
// Tuple<F16>,
// F16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
@@ -297,20 +261,19 @@ namespace instance {
// PassThrough,
// AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// F16,
// F16,
// Tuple<F16>,
// F16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
@@ -326,65 +289,33 @@ void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_
PassThrough,
AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// F16,
// F16,
// Tuple<F16>,
// F16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// F16,
// F16,
// Tuple<F16>,
// F16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// F16,
// F16,
// Tuple<F16>,
// F16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// F16,
// F16,
// Tuple<F16>,
// F16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
// void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,

View File

@@ -28,5 +28,18 @@ add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance
xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_inter_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp
# WMMA CSHUFFLE V3
wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp
wmma/merged_groups/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/mem/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
wmma/mem/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
wmma/comp/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_instance.cpp
wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_16x16_instance.cpp
wmma/merged_groups/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_fp16_instance.cpp
wmma/mem/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_mem_intra_instance.cpp
wmma/mem/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_mem_inter_instance.cpp
wmma/comp/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_comp_instance.cpp
)

View File

@@ -0,0 +1,64 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_comp_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_comp_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_comp_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_comp_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -35,25 +35,27 @@ void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_
Tuple<F16>,
AddClamp>{});
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_f16_comp_instances<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// ConvFwd1x1P0,
// Tuple<F16>,
// AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_comp_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Tuple<F16>,
AddClamp>{});
// add_device_operation_instances(instances,
// device_grouped_conv_fwd_wmma_f16_comp_instances<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// ConvFwd1x1S1P0,
// Tuple<F16>,
// AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_comp_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<F16>,
AddClamp>{});
}
} // namespace instance

View File

@@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_16x16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_16x16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_16x16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_16x16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_16x16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_16x16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<F16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<F16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Interwave,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Interwave,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Interwave,
Tuple<BF16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Intrawave,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Intrawave,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Intrawave,
Tuple<BF16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Interwave,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Interwave,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Interwave,
Tuple<F16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Intrawave,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Intrawave,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Intrawave,
Tuple<F16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,54 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_bf16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_bf16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd3x3,
Tuple<BF16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,54 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_f16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_f16_instances<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd3x3,
Tuple<F16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -4,7 +4,7 @@ if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATC
target_link_libraries(test_grouped_convnd_fwd_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance)
add_gtest_executable(test_grouped_convnd_fwd_gk_bias_clamp test_grouped_convnd_fwd_gk_bias_clamp.cpp)
target_link_libraries(test_grouped_convnd_fwd_gk_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance)
target_link_libraries(test_grouped_convnd_fwd_gk_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance)
add_gtest_executable(test_grouped_convnd_fwd_clamp test_grouped_convnd_fwd_clamp.cpp)
target_link_libraries(test_grouped_convnd_fwd_clamp PRIVATE utility device_grouped_conv2d_fwd_clamp_instance device_grouped_conv3d_fwd_clamp_instance)

View File

@@ -46,7 +46,7 @@ class TestGroupedConvndFwd : public ::testing::Test
true, // do_verification
1, // init_method: integer value
false, // do_log
false, // time_kernel
true, // time_kernel
param);
}
EXPECT_TRUE(pass);
@@ -79,10 +79,35 @@ TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndFwd2d, Test2D)
{
this->conv_params.clear();
this->conv_params.push_back(
{2, 3, 5, 96, 200, {1, 1}, {73, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {1, 1}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {2, 2}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {3, 3}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {5, 5}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {9, 9}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 96, 1, 1, 1, {1, 1}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->template Run<2>();
}

View File

@@ -46,7 +46,7 @@ class TestGroupedConvndFwd : public ::testing::Test
true, // do_verification
1, // init_method: integer value
false, // do_log
false, // time_kernel
true, // time_kernel
param);
}
EXPECT_TRUE(pass);
@@ -79,10 +79,34 @@ TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndFwd2d, Test2D)
{
this->conv_params.clear();
this->conv_params.push_back(
{2, 3, 5, 96, 200, {1, 1}, {73, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {1, 1}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {2, 2}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {3, 3}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {5, 5}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {9, 9}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 96, 1, 1, 1, {1, 1}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->template Run<2>();
}