mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
Conv + quantization + tanh (#645)
* Rename file. Prepare to support another activation
* Add comment for quantization
* Extract out_elementop
* Add tanh example
* Add conv + bias + tanh quantization instance
* Add missing parameter
* Refine cmake
* Add external api and client example
* Extract variable in example
* Fix the comment
---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>
[ROCm/composable_kernel commit: 389e84a83b]
This commit is contained in:
@@ -85,6 +85,7 @@ using GK_GK_Tuple = ck::Tuple<GK, GK>;
|
||||
// pointwise functor
// Short aliases for element-wise operation types declared in
// ck::tensor_operation::element_wise, so the instance declarations can refer
// to them without the full namespace qualification.
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Relu = ck::tensor_operation::element_wise::Relu;
|
||||
using TanH = ck::tensor_operation::element_wise::TanH;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
|
||||
@@ -102,6 +103,10 @@ template <typename Activation>
|
||||
using Add_Activation_Mul_Clamp =
|
||||
ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>;
|
||||
|
||||
// Alias template: shorthand for
// ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp<Activation>,
// parameterized on the activation functor type.
// NOTE(review): the name suggests add -> mul -> activation -> mul -> clamp;
// confirm the exact order against the functor's definition.
template <typename Activation>
|
||||
using Add_Mul_Activation_Mul_Clamp =
|
||||
ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp<Activation>;
|
||||
|
||||
// Alias template: shorthand for
// ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Activation>,
// parameterized on the activation functor type.
template <typename Activation>
|
||||
using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Activation>;
|
||||
|
||||
@@ -109,6 +114,10 @@ template <typename Activation>
|
||||
using Add_Activation_Mul2_Clamp =
|
||||
ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Activation>;
|
||||
|
||||
// Alias template: shorthand for
// ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp<Activation>,
// parameterized on the activation functor type.
// NOTE(review): presumably the "Mul2" variant takes a second (per-channel)
// scale operand — confirm against the functor's definition.
template <typename Activation>
|
||||
using Add_Mul2_Activation_Mul_Clamp =
|
||||
ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp<Activation>;
|
||||
|
||||
// Primary template of the instance factory. It is only declared here (no
// body); `Tag` defaults to void.
// NOTE(review): presumably every usable factory is a specialization keyed on
// the DeviceOp type — confirm that no primary definition exists elsewhere.
template <typename DeviceOp, typename Tag = void>
|
||||
struct DeviceOperationInstanceFactory;
|
||||
|
||||
|
||||
@@ -49,6 +49,22 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
|
||||
Add_Activation_Mul2_Clamp<Relu>>>>&
|
||||
instances);
|
||||
|
||||
// Declaration only; the definition is not part of this view.
// `instances` is taken by non-const reference. It is a vector of owning
// pointers to DeviceGroupedConvFwdMultipleD fixed to: 2 spatial dimensions,
// GNHWC / GKYXC / GK_GK_Tuple / GNHWK layouts, int8_t / int8_t /
// I32_F32_Tuple / int8_t data types, two PassThrough ops, and
// Add_Mul2_Activation_Mul_Clamp<TanH> as the last element-wise op.
// NOTE(review): presumably appends the "dl" kernel instances to `instances`
// — confirm against the definition.
void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
|
||||
instances);
|
||||
|
||||
void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
@@ -80,6 +96,23 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
|
||||
Add_Activation_Mul2_Clamp<Relu>>>>&
|
||||
instances);
|
||||
|
||||
// Declaration only; the definition is not part of this view.
// `instances` is taken by non-const reference. It is a vector of owning
// pointers to DeviceGroupedConvFwdMultipleD fixed to: 2 spatial dimensions,
// GNHWC / GKYXC / GK_GK_Tuple / GNHWK layouts, int8_t / int8_t /
// I32_F32_Tuple / int8_t data types, two PassThrough ops, and
// Add_Mul2_Activation_Mul_Clamp<TanH> as the last element-wise op.
// NOTE(review): presumably appends the "xdl" kernel instances to `instances`
// — confirm against the definition.
void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add_Mul2_Activation_Mul_Clamp<TanH>>>>&
|
||||
instances);
|
||||
|
||||
// piecewise activation function
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
@@ -145,6 +178,67 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
}
|
||||
};
|
||||
|
||||
// non-piecewise activation function
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename DsLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename DsDataType,
|
||||
typename OutDataType,
|
||||
typename Activation>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Add_Mul2_Activation_Mul_Clamp<Activation>>>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Add_Mul2_Activation_Mul_Clamp<Activation>>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> &&
|
||||
is_same_v<OutLayout, GNHWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
if constexpr(is_same_v<Activation, TanH>)
|
||||
{
|
||||
add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
|
||||
add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -49,6 +49,21 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
|
||||
Add_Activation_Mul_Clamp<Relu>>>>&
|
||||
instances);
|
||||
|
||||
// Declaration only; the definition is not part of this view.
// `instances` is taken by non-const reference. It is a vector of owning
// pointers to DeviceGroupedConvFwdMultipleD fixed to: 2 spatial dimensions,
// GNHWC / GKYXC / GK_Tuple / GNHWK layouts, int8_t / int8_t / I32_Tuple /
// int8_t data types, two PassThrough ops, and
// Add_Mul_Activation_Mul_Clamp<TanH> as the last element-wise op.
// NOTE(review): presumably appends the "dl" kernel instances to `instances`
// — confirm against the definition.
void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add_Mul_Activation_Mul_Clamp<TanH>>>>&
|
||||
instances);
|
||||
|
||||
void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
@@ -80,6 +95,22 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
|
||||
Add_Activation_Mul_Clamp<Relu>>>>&
|
||||
instances);
|
||||
|
||||
// Declaration only; the definition is not part of this view.
// `instances` is taken by non-const reference. It is a vector of owning
// pointers to DeviceGroupedConvFwdMultipleD fixed to: 2 spatial dimensions,
// GNHWC / GKYXC / GK_Tuple / GNHWK layouts, int8_t / int8_t / I32_Tuple /
// int8_t data types, two PassThrough ops, and
// Add_Mul_Activation_Mul_Clamp<TanH> as the last element-wise op.
// NOTE(review): presumably appends the "xdl" kernel instances to `instances`
// — confirm against the definition.
void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add_Mul_Activation_Mul_Clamp<TanH>>>>&
|
||||
instances);
|
||||
|
||||
// piecewise activation function
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
@@ -145,6 +176,67 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
}
|
||||
};
|
||||
|
||||
// non-piecewise activation function
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename DsLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename DsDataType,
|
||||
typename OutDataType,
|
||||
typename Activation>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Add_Mul_Activation_Mul_Clamp<Activation>>>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Add_Mul_Activation_Mul_Clamp<Activation>>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
|
||||
is_same_v<OutLayout, GNHWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
if constexpr(is_same_v<Activation, TanH>)
|
||||
{
|
||||
add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs);
|
||||
add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
Reference in New Issue
Block a user