NHWC conv 2d: fwd bfp16/int8, Device level tuning and host API (#73)

* add fwd bf16 conv

* change tunning parametor

* add int8 for conv fwd

* remove comments

* change tunning parametor for int8

* change init int8 example

* add test for conv2d fwd

* change device operation file pos because merge develop

* fwd int8 use reference

* test_conv_fwd use reference

* add braket for if statement

* rename fwd example name

* remove StaticBufferOfVectorTypeV2

* tweak example

Co-authored-by: ltqin <letaoqin@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>

[ROCm/composable_kernel commit: 880fbee957]
This commit is contained in:
ltqin
2022-02-12 10:06:40 +08:00
committed by GitHub
parent a76cc6deed
commit 32c128bcc5
16 changed files with 960 additions and 130 deletions

View File

@@ -25,6 +25,9 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector<DeviceCo
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace device_conv2d_fwd_instance
} // namespace device
} // namespace tensor_operation
@@ -171,6 +174,20 @@ void profile_conv_fwd_impl(int do_verification,
ck::tensor_operation::device::device_conv2d_fwd_instance::
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
}
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ushort> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ushort> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ushort>)
{
ck::tensor_operation::device::device_conv2d_fwd_instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
}
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, int8_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, int8_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, int8_t>)
{
ck::tensor_operation::device::device_conv2d_fwd_instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
}
if(conv_ptrs.size() <= 0)
{