mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
NHWC conv 2d: fwd bfp16/int8, Device level tuning and host API (#73)
* add fwd bf16 conv
* change tuning parameter
* add int8 for conv fwd
* remove comments
* change tuning parameter for int8
* change init int8 example
* add test for conv2d fwd
* change device operation file position because of merge with develop
* fwd int8 use reference
* test_conv_fwd use reference
* add braces for if statement
* rename fwd example name
* remove StaticBufferOfVectorTypeV2
* tweak example
Co-authored-by: ltqin <letaoqin@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 880fbee957]
This commit is contained in:
@@ -25,6 +25,9 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector<DeviceCo
|
||||
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
|
||||
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
|
||||
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
} // namespace device_conv2d_fwd_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
@@ -171,6 +174,20 @@ void profile_conv_fwd_impl(int do_verification,
|
||||
ck::tensor_operation::device::device_conv2d_fwd_instance::
|
||||
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
|
||||
}
|
||||
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ushort> &&
|
||||
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ushort> &&
|
||||
ck::is_same_v<ck::remove_cv_t<OutDataType>, ushort>)
|
||||
{
|
||||
ck::tensor_operation::device::device_conv2d_fwd_instance::
|
||||
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
|
||||
}
|
||||
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, int8_t> &&
|
||||
ck::is_same_v<ck::remove_cv_t<WeiDataType>, int8_t> &&
|
||||
ck::is_same_v<ck::remove_cv_t<OutDataType>, int8_t>)
|
||||
{
|
||||
ck::tensor_operation::device::device_conv2d_fwd_instance::
|
||||
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
|
||||
}
|
||||
|
||||
if(conv_ptrs.size() <= 0)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user