mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Revise layout of group convolution (#675)
* [What] Remove pure conv int8 instance
[Why] We will never use pure int8 conv in AI, use int8 quantization instead
* Change layout
* Share the kernel parameter
* Support more type of NHWGC for group conv
* Revise client example of conv 2d, use NHWGC layout
* Add instance to cmake
* Revise layout of group conv quantization instance
* Revise layout of external api of group conv quantization
* Revise layout of group conv quantization client example
* Fix clang format
* Add comment to describe meaning of each parameter
[ROCm/composable_kernel commit: 3eecbfb6ec]
This commit is contained in:
@@ -19,9 +19,9 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using GNHWC = ck::tensor_layout::convolution::GNHWC;
|
||||
using NHWGC = ck::tensor_layout::convolution::NHWGC;
|
||||
using GKYXC = ck::tensor_layout::convolution::GKYXC;
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
using NHWGK = ck::tensor_layout::convolution::NHWGK;
|
||||
using GK = ck::tensor_layout::convolution::G_K;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Relu = ck::tensor_operation::element_wise::Relu;
|
||||
|
||||
@@ -9,10 +9,10 @@ namespace device {
|
||||
namespace instance {
|
||||
void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
@@ -23,19 +23,28 @@ void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
|
||||
{
|
||||
// dl
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_Clamp,
|
||||
ConvFwdDefault,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
@@ -44,10 +53,10 @@ void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
@@ -58,19 +67,28 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
|
||||
{
|
||||
// dl
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Relu_Mul2_Clamp,
|
||||
ConvFwdDefault,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Relu_Mul2_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Relu_Mul2_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
@@ -79,10 +97,10 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
@@ -93,19 +111,28 @@ void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
|
||||
{
|
||||
// dl
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_TanH_Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_TanH_Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_TanH_Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
|
||||
@@ -9,10 +9,10 @@ namespace device {
|
||||
namespace instance {
|
||||
void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
@@ -22,19 +22,28 @@ void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
|
||||
Add_Mul_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
@@ -43,10 +52,10 @@ void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
@@ -56,21 +65,30 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
|
||||
Add_Relu_Mul_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Relu_Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
4>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Relu_Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
4>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Relu_Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
@@ -79,10 +97,10 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
@@ -92,21 +110,30 @@ void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
|
||||
Add_Mul_TanH_Mul_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Mul_TanH_Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
4>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Mul_TanH_Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
4>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Mul_TanH_Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
|
||||
@@ -12,7 +12,10 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// clang-format off
|
||||
template <typename DsLayout,
|
||||
template <typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename DsLayout,
|
||||
typename OutLayout,
|
||||
typename DsDatatype,
|
||||
typename OutElementOp,
|
||||
ConvolutionForwardSpecialization ConvSpec,
|
||||
@@ -23,7 +26,7 @@ using device_grouped_conv2d_dl_int8_instances =
|
||||
// ###########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
|
||||
// ###########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
|
||||
// ###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< NDimSpatial, int8_t, int8_t, DsDatatype, int8_t, int32_t, GNHWC, GKYXC, DsLayout, GNHWK, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, DstScalarPerVector>
|
||||
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< NDimSpatial, int8_t, int8_t, DsDatatype, int8_t, int32_t, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, DstScalarPerVector>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -9,10 +9,10 @@ namespace device {
|
||||
namespace instance {
|
||||
void add_device_conv2d_dl_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
F32_Tuple,
|
||||
@@ -22,19 +22,28 @@ void add_device_conv2d_dl_perchannel_quantization_int8_instances(
|
||||
Mul2_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
F32_Tuple,
|
||||
Mul2_Clamp,
|
||||
ConvFwdDefault,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
F32_Tuple,
|
||||
Mul2_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
F32_Tuple,
|
||||
Mul2_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
@@ -43,10 +52,10 @@ void add_device_conv2d_dl_perchannel_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
F32_Tuple,
|
||||
@@ -56,19 +65,28 @@ void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
|
||||
Relu_Mul2_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
F32_Tuple,
|
||||
Relu_Mul2_Clamp,
|
||||
ConvFwdDefault,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
F32_Tuple,
|
||||
Relu_Mul2_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
F32_Tuple,
|
||||
Relu_Mul2_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
|
||||
@@ -9,10 +9,10 @@ namespace device {
|
||||
namespace instance {
|
||||
void add_device_conv2d_dl_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
@@ -22,19 +22,28 @@ void add_device_conv2d_dl_perlayer_quantization_int8_instances(
|
||||
Mul_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<Empty_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
Empty_Tuple,
|
||||
Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<Empty_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
Empty_Tuple,
|
||||
Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<Empty_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
Empty_Tuple,
|
||||
Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
@@ -43,10 +52,10 @@ void add_device_conv2d_dl_perlayer_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
@@ -56,19 +65,28 @@ void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
|
||||
Relu_Mul_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<Empty_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
Empty_Tuple,
|
||||
Relu_Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<Empty_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
Empty_Tuple,
|
||||
Relu_Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<Empty_Tuple,
|
||||
device_grouped_conv2d_dl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
Empty_Tuple,
|
||||
Relu_Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
|
||||
@@ -9,10 +9,10 @@ namespace device {
|
||||
namespace instance {
|
||||
void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
@@ -22,19 +22,28 @@ void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
|
||||
Add_Mul2_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_Clamp,
|
||||
ConvFwdDefault,
|
||||
8>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
8>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
@@ -43,10 +52,10 @@ void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
@@ -56,19 +65,28 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
|
||||
Add_Relu_Mul2_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Relu_Mul2_Clamp,
|
||||
ConvFwdDefault,
|
||||
8>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Relu_Mul2_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
8>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Relu_Mul2_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
@@ -77,10 +95,10 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
@@ -90,19 +108,28 @@ void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(
|
||||
Add_Mul2_TanH_Mul_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_TanH_Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
8>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_TanH_Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
8>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
NHWGK,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_TanH_Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
|
||||
@@ -9,10 +9,10 @@ namespace device {
|
||||
namespace instance {
|
||||
void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
@@ -22,19 +22,28 @@ void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
|
||||
Add_Mul_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
8>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
8>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
@@ -43,10 +52,10 @@ void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
@@ -56,21 +65,30 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
|
||||
Add_Relu_Mul_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Relu_Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
8>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Relu_Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
8>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Relu_Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
@@ -79,10 +97,10 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
@@ -92,21 +110,30 @@ void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(
|
||||
Add_Mul_TanH_Mul_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Mul_TanH_Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
8>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Mul_TanH_Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
8>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
I32_Tuple,
|
||||
Add_Mul_TanH_Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
|
||||
@@ -12,30 +12,33 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// clang-format off
|
||||
template <typename DsLayout,
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename DsDatatype,
|
||||
typename OutElementOp,
|
||||
ConvolutionForwardSpecialization ConvSpec,
|
||||
index_t DstScalarPerVector>
|
||||
using device_grouped_conv2d_xdl_int8_instances =
|
||||
std::tuple <
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -9,10 +9,10 @@ namespace device {
|
||||
namespace instance {
|
||||
void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
F32_Tuple,
|
||||
@@ -22,19 +22,28 @@ void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
|
||||
Mul2_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
F32_Tuple,
|
||||
Mul2_Clamp,
|
||||
ConvFwdDefault,
|
||||
8>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
F32_Tuple,
|
||||
Mul2_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
8>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
F32_Tuple,
|
||||
Mul2_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
@@ -43,10 +52,10 @@ void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
F32_Tuple,
|
||||
@@ -56,19 +65,28 @@ void add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(
|
||||
Relu_Mul2_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
F32_Tuple,
|
||||
Relu_Mul2_Clamp,
|
||||
ConvFwdDefault,
|
||||
8>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
F32_Tuple,
|
||||
Relu_Mul2_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
8>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
NHWGK,
|
||||
F32_Tuple,
|
||||
Relu_Mul2_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
|
||||
@@ -9,10 +9,10 @@ namespace device {
|
||||
namespace instance {
|
||||
void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
@@ -22,19 +22,28 @@ void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
|
||||
Mul_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
Empty_Tuple,
|
||||
Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
16>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
Empty_Tuple,
|
||||
Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
16>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
Empty_Tuple,
|
||||
Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
@@ -43,10 +52,10 @@ void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
|
||||
|
||||
void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
@@ -56,19 +65,28 @@ void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(
|
||||
Relu_Mul_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
Empty_Tuple,
|
||||
Relu_Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
16>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
Empty_Tuple,
|
||||
Relu_Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
16>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple,
|
||||
device_grouped_conv2d_xdl_int8_instances<NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
Empty_Tuple,
|
||||
Relu_Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
|
||||
Reference in New Issue
Block a user