Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-17 03:19:48 +00:00.
Conv + quantization + tanh (#645)
* Rename file. Prepare to support another activation
* Add comment for quantization
* Extract out_elementop
* Add tanh example
* Add conv + bias + tanh quantization instance
* Add missing parameter
* Refine cmake
* Add external api and client example
* Extract variable in example
* Fix the comment
---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>
[ROCm/composable_kernel commit: 389e84a83b]
This commit is contained in:
@@ -25,6 +25,7 @@ using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
using GK = ck::tensor_layout::convolution::G_K;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Relu = ck::tensor_operation::element_wise::Relu;
|
||||
using TanH = ck::tensor_operation::element_wise::TanH;
|
||||
|
||||
using GK_Tuple = ck::Tuple<GK>;
|
||||
using GK_GK_Tuple = ck::Tuple<GK, GK>;
|
||||
@@ -32,17 +33,25 @@ using I32_Tuple = ck::Tuple<int32_t>;
|
||||
using F32_Tuple = ck::Tuple<float>;
|
||||
using I32_F32_Tuple = ck::Tuple<int32_t, float>;
|
||||
|
||||
// perlayer
|
||||
using Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<PassThrough>;
|
||||
using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Relu>;
|
||||
|
||||
// bias + perlayer
|
||||
using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<PassThrough>;
|
||||
using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Relu>;
|
||||
using Add_Mul_TanH_Mul_Clamp =
|
||||
ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp<TanH>;
|
||||
|
||||
// perchannel
|
||||
using Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<PassThrough>;
|
||||
using Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Relu>;
|
||||
|
||||
// bias + perchannel
|
||||
using Add_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<PassThrough>;
|
||||
using Add_Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Relu>;
|
||||
using Add_Mul2_TanH_Mul_Clamp =
|
||||
ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp<TanH>;
|
||||
|
||||
static constexpr ck::index_t NDimSpatial = 2;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
@@ -76,6 +76,42 @@ void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(
|
||||
ConvFwd1x1S1P0,
|
||||
4>{});
|
||||
}
|
||||
|
||||
void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add_Mul2_TanH_Mul_Clamp>>>& instances)
|
||||
{
|
||||
// dl
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_GK_Tuple,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_TanH_Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_GK_Tuple,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_TanH_Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
4>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_GK_Tuple,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_TanH_Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
4>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -76,6 +76,43 @@ void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(
|
||||
ConvFwd1x1S1P0,
|
||||
4>{});
|
||||
}
|
||||
|
||||
void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add_Mul_TanH_Mul_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
I32_Tuple,
|
||||
Add_Mul_TanH_Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
4>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
I32_Tuple,
|
||||
Add_Mul_TanH_Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
4>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_dl_int8_instances<GK_Tuple,
|
||||
I32_Tuple,
|
||||
Add_Mul_TanH_Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
4>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -74,6 +74,41 @@ void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(
|
||||
ConvFwd1x1S1P0,
|
||||
8>{});
|
||||
}
|
||||
|
||||
void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GK_GK_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_F32_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add_Mul2_TanH_Mul_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_GK_Tuple,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_TanH_Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
8>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_GK_Tuple,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_TanH_Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
8>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_GK_Tuple,
|
||||
I32_F32_Tuple,
|
||||
Add_Mul2_TanH_Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
8>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -76,6 +76,43 @@ void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(
|
||||
ConvFwd1x1S1P0,
|
||||
8>{});
|
||||
}
|
||||
|
||||
void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GK_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
I32_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add_Mul_TanH_Mul_Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
I32_Tuple,
|
||||
Add_Mul_TanH_Mul_Clamp,
|
||||
ConvFwdDefault,
|
||||
8>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
I32_Tuple,
|
||||
Add_Mul_TanH_Mul_Clamp,
|
||||
ConvFwd1x1P0,
|
||||
8>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_xdl_int8_instances<GK_Tuple,
|
||||
I32_Tuple,
|
||||
Add_Mul_TanH_Mul_Clamp,
|
||||
ConvFwd1x1S1P0,
|
||||
8>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
Reference in New Issue
Block a user