Conv + quantization + tanh (#645)

* Rename file. Prepare to support another activation

* Add comment for quantization

* Extract out_elementop

* Add tanh example

* Add conv + bias + tanh quantization instance

* Add missing parameter

* Refine cmake

* Add external api and client example

* Extract variable in example

* Fix the comment

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>

[ROCm/composable_kernel commit: 389e84a83b]
This commit is contained in:
rocking5566
2023-03-30 03:50:23 +08:00
committed by GitHub
parent ec634a3d32
commit c8d839b5d9
31 changed files with 1252 additions and 122 deletions

View File

@@ -85,6 +85,7 @@ using GK_GK_Tuple = ck::Tuple<GK, GK>;
// pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
using TanH = ck::tensor_operation::element_wise::TanH;
using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
@@ -102,6 +103,10 @@ template <typename Activation>
using Add_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>;
template <typename Activation>
using Add_Mul_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp<Activation>;
template <typename Activation>
using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Activation>;
@@ -109,6 +114,10 @@ template <typename Activation>
using Add_Activation_Mul2_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Activation>;
template <typename Activation>
using Add_Mul2_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp<Activation>;
template <typename DeviceOp, typename Tag = void>
struct DeviceOperationInstanceFactory;

View File

@@ -49,6 +49,22 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
Add_Activation_Mul2_Clamp<Relu>>>>&
instances);
void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
instances);
void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
@@ -80,6 +96,23 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
Add_Activation_Mul2_Clamp<Relu>>>>&
instances);
void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
instances);
// piecewise activation function
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
@@ -145,6 +178,67 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
};
// non-piecewise activation function
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename DsDataType,
typename OutDataType,
typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Mul2_Activation_Mul_Clamp<Activation>>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Mul2_Activation_Mul_Clamp<Activation>>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> &&
is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>)
{
if constexpr(is_same_v<Activation, TanH>)
{
add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
}
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation

View File

@@ -49,6 +49,21 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
Add_Activation_Mul_Clamp<Relu>>>>&
instances);
void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>&
instances);
void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
@@ -80,6 +95,22 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
Add_Activation_Mul_Clamp<Relu>>>>&
instances);
void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Mul_Activation_Mul_Clamp<TanH>>>>&
instances);
// piecewise activation function
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
@@ -145,6 +176,67 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
};
// non-piecewise activation function
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename DsDataType,
typename OutDataType,
typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Mul_Activation_Mul_Clamp<Activation>>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Mul_Activation_Mul_Clamp<Activation>>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>)
{
if constexpr(is_same_v<Activation, TanH>)
{
add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs);
add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs);
}
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation