Revise layout of group convolution (#675)

* [What] Remove pure conv int8 instance
[Why] We will never use pure int8 conv in AI, use int8 quantization instead

* Change layout

* Share the kernel parameter

* Support more type of NHWGC for group conv

* Revise client example of conv 2d, use NHWGC layout

* Add instance to cmake

* Revise layout of group conv quantization instance

* Revise layout of external api of group conv quantization

* Revise layout of group conv quantization client example

* Fix clang format

* Add comment to describe meaning of each parameter
This commit is contained in:
rocking
2023-04-24 12:40:00 +08:00
committed by GitHub
parent 903cd19ce3
commit 3eecbfb6ec
41 changed files with 1079 additions and 1222 deletions

View File

@@ -117,20 +117,6 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
@@ -159,20 +145,21 @@ void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
NHWGK,
BF16,
BF16,
Empty_Tuple,
int8_t,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
@@ -187,6 +174,20 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
@@ -385,12 +386,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
@@ -398,7 +393,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
// no instance
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
@@ -409,12 +404,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
// no instance
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
// no instance
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&

View File

@@ -17,14 +17,14 @@ namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
I32_F32_Tuple,
@@ -36,10 +36,10 @@ void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
I32_F32_Tuple,
@@ -52,10 +52,10 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
I32_F32_Tuple,
@@ -68,10 +68,10 @@ void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
I32_F32_Tuple,
@@ -83,10 +83,10 @@ void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
I32_F32_Tuple,
@@ -99,10 +99,10 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
I32_F32_Tuple,
@@ -154,9 +154,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> &&
is_same_v<OutLayout, GNHWK>)
is_same_v<OutLayout, NHWGK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>)
@@ -220,9 +220,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> &&
is_same_v<OutLayout, GNHWK>)
is_same_v<OutLayout, NHWGK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>)

View File

@@ -17,14 +17,14 @@ namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
I32_Tuple,
@@ -36,10 +36,10 @@ void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
I32_Tuple,
@@ -51,10 +51,10 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
I32_Tuple,
@@ -67,10 +67,10 @@ void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
I32_Tuple,
@@ -82,10 +82,10 @@ void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
I32_Tuple,
@@ -97,10 +97,10 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
I32_Tuple,
@@ -152,9 +152,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
is_same_v<OutLayout, GNHWK>)
is_same_v<OutLayout, NHWGK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>)
@@ -218,9 +218,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
is_same_v<OutLayout, GNHWK>)
is_same_v<OutLayout, NHWGK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>)

View File

@@ -17,13 +17,13 @@ namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
F32_Tuple,
@@ -35,10 +35,10 @@ void add_device_conv2d_dl_perchannel_quantization_int8_instances(
void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
F32_Tuple,
@@ -50,10 +50,10 @@ void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
F32_Tuple,
@@ -65,10 +65,10 @@ void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
void add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
GK_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
F32_Tuple,
@@ -119,9 +119,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
is_same_v<OutLayout, GNHWK>)
is_same_v<OutLayout, NHWGK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)

View File

@@ -17,13 +17,13 @@ namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_conv2d_dl_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
Empty_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
@@ -35,10 +35,10 @@ void add_device_conv2d_dl_perlayer_quantization_int8_instances(
void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
Empty_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
@@ -50,10 +50,10 @@ void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
Empty_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
@@ -65,10 +65,10 @@ void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
NHWGC,
GKYXC,
Empty_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
@@ -117,8 +117,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)