mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
Revise layout of group convolution (#675)
* [What] Remove pure conv int8 instance [Why] We will never use pure int8 conv in AI, use int8 quantization instead * Change layout * Share the kernel parameter * Support more type of NHWGC for group conv * Revise client example of conv 2d, use NHWGC layout * Add instance to cmake * Revise layout of group conv quantization instance * Revise layout of external api of group conv quantization * Revise layout of group conv quantization client example * Fix clang format * Add comment to describe meaning of each parameter
This commit is contained in:
@@ -117,20 +117,6 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
@@ -159,20 +145,21 @@ void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(
|
||||
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
NHWGC,
|
||||
@@ -187,6 +174,20 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
|
||||
@@ -385,12 +386,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
|
||||
@@ -398,7 +393,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
// no instance
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
@@ -409,12 +404,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
// no instance
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
// no instance
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
|
||||
|
||||
@@ -17,14 +17,14 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
|
||||
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
|
||||
void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
@@ -36,10 +36,10 @@ void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
@@ -52,10 +52,10 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
|
||||
void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
@@ -68,10 +68,10 @@ void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
|
||||
void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
@@ -83,10 +83,10 @@ void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
@@ -99,10 +99,10 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
|
||||
void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
@@ -154,9 +154,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> &&
|
||||
is_same_v<OutLayout, GNHWK>)
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>)
|
||||
@@ -220,9 +220,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> &&
|
||||
is_same_v<OutLayout, GNHWK>)
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>)
|
||||
|
||||
@@ -17,14 +17,14 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
|
||||
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
|
||||
void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
@@ -36,10 +36,10 @@ void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
@@ -51,10 +51,10 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
@@ -67,10 +67,10 @@ void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
|
||||
void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
@@ -82,10 +82,10 @@ void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
@@ -97,10 +97,10 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
@@ -152,9 +152,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
|
||||
is_same_v<OutLayout, GNHWK>)
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>)
|
||||
@@ -218,9 +218,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
|
||||
is_same_v<OutLayout, GNHWK>)
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>)
|
||||
|
||||
@@ -17,13 +17,13 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
|
||||
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
|
||||
void add_device_conv2d_dl_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
F32_Tuple,
|
||||
@@ -35,10 +35,10 @@ void add_device_conv2d_dl_perchannel_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
F32_Tuple,
|
||||
@@ -50,10 +50,10 @@ void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
F32_Tuple,
|
||||
@@ -65,10 +65,10 @@ void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
F32_Tuple,
|
||||
@@ -119,9 +119,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
|
||||
is_same_v<OutLayout, GNHWK>)
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
|
||||
@@ -17,13 +17,13 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
|
||||
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
|
||||
void add_device_conv2d_dl_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
@@ -35,10 +35,10 @@ void add_device_conv2d_dl_perlayer_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
@@ -50,10 +50,10 @@ void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
@@ -65,10 +65,10 @@ void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
@@ -117,8 +117,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
|
||||
Reference in New Issue
Block a user