Add 2D and 3D clamp instances and tests

This commit is contained in:
kiefer
2025-09-01 15:00:59 +00:00
parent 0b8de9a0dc
commit 9416c82bfa
29 changed files with 2085 additions and 3 deletions

View File

@@ -16,6 +16,10 @@
#include "grouped_convolution_forward_clamp_xdl.inc"
#endif
#ifdef CK_USE_WMMA
#include "grouped_convolution_forward_clamp_wmma_cshufflev3.inc"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
@@ -214,6 +218,107 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif // CK_USE_XDL
#ifdef CK_USE_WMMA
// layout NHWGC/GKYXC/NHWGK
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
op_ptrs);
// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
// op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_16x16_instances(
op_ptrs);
// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
// op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances(
op_ptrs);
}
#endif
}
// layout NDHWGC/GKZYXC/NDHWGK
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
op_ptrs);
// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
// op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances(
op_ptrs);
// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
// op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances(
op_ptrs);
}
#endif
}
#endif // CK_USE_WMMA
return op_ptrs;
}
};

View File

@@ -0,0 +1,418 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<>,
// NHWGK,
// BF16,
// BF16,
// Tuple<>,
// BF16,
// PassThrough,
// PassThrough,
// Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
// void
// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
// NDHWGC,
// GKZYXC,
// Tuple<>,
// NDHWGK,
// BF16,
// BF16,
// Tuple<>,
// BF16,
// PassThrough,
// PassThrough,
// Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<>,
// NHWGK,
// F16,
// F16,
// Tuple<>,
// F16,
// PassThrough,
// PassThrough,
// Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
// void
// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
// NDHWGC,
// GKZYXC,
// Tuple<>,
// NDHWGK,
// F16,
// F16,
// Tuple<>,
// F16,
// PassThrough,
// PassThrough,
// Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,4 +1,4 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_grouped_conv2d_fwd_clamp_instance
xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp
@@ -27,4 +27,19 @@ add_instance_library(device_grouped_conv2d_fwd_clamp_instance
xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_inter_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp
# WMMA CSHUFFLE V3
wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp
wmma/merged_groups/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/mem/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
wmma/mem/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
wmma/comp/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_instance.cpp
wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_16x16_instance.cpp
wmma/merged_groups/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_fp16_instance.cpp
wmma/mem/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_mem_intra_instance.cpp
wmma/mem/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_mem_inter_instance.cpp
wmma/comp/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_fp16_comp_instance.cpp
)

View File

@@ -0,0 +1,64 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_comp_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_comp_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_comp_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_comp_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,64 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_comp_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_comp_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_comp_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_comp_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_16x16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_16x16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_16x16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_16x16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_16x16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_16x16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Interwave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Interwave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Interwave,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Intrawave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Intrawave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Intrawave,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Interwave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Interwave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Interwave,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Intrawave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Intrawave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Intrawave,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,54 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_bf16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_bf16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd3x3,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,53 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_f16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_f16_instances<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd3x3,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,4 +1,4 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
set(GROUPED_CONV3D_FWD
xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp
@@ -23,6 +23,21 @@ set(GROUPED_CONV3D_FWD
xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp
xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp
xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp
# WMMA CSHUFFLE V3
wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp
wmma/merged_groups/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
wmma/mem/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp
wmma/mem/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp
wmma/comp/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp
wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_fp16_instance.cpp
wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_fp16_16x16_instance.cpp
wmma/merged_groups/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp16_instance.cpp
wmma/mem/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_fp16_mem_inter_instance.cpp
wmma/mem/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_fp16_mem_intra_instance.cpp
wmma/comp/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_fp16_comp_instance.cpp
)
add_instance_library(device_grouped_conv3d_fwd_clamp_instance ${GROUPED_CONV3D_FWD})

View File

@@ -0,0 +1,62 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_comp_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_comp_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_comp_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_comp_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,62 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_comp_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_comp_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_comp_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_comp_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_16x16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_16x16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_16x16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_16x16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_16x16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_16x16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,64 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Interwave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Interwave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Interwave,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,64 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Intrawave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Intrawave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_mem_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Intrawave,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,64 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Interwave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Interwave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Interwave,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,64 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Intrawave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Intrawave,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_mem_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Intrawave,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,51 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_bf16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_bf16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd3x3,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,51 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_f16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_merged_groups_f16_instances<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd3x3,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -47,7 +47,7 @@ class TestGroupedConvndFwd : public ::testing::Test
true, // do_verification
1, // init_method: integer value
false, // do_log
false, // time_kernel
true, // time_kernel
param,
out_element_op);
}
@@ -81,19 +81,71 @@ TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndFwd2d, Test2D)
{
this->conv_params.clear();
this->conv_params.push_back(
{2, 3, 5, 96, 200, {1, 1}, {73, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {1, 1}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {2, 2}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {3, 3}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {5, 5}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {9, 9}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 96, 1, 1, 1, {1, 1}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->template Run<2>();
}
TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
{
this->conv_params.clear();
this->conv_params.push_back(
{3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->template Run<3>();
}