Fix conv2d bwd data bug when filter is 1x1 and stride = 2 (#132)

* fix bwd data filter1strid2 bug

* fichangeshort to ck::bhalf_t

* reset input to zero

Co-authored-by: ltqin <letaoqin@amd.com>
This commit is contained in:
ltqin
2022-03-21 23:53:23 +08:00
committed by GitHub
parent 9a17e7fbfd
commit b51808d7a5
5 changed files with 24 additions and 10 deletions

View File

@@ -182,8 +182,8 @@ int main(int argc, char* argv[])
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{5});
// reset input to zero
in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
// get host result
@@ -225,9 +225,9 @@ int main(int argc, char* argv[])
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
}
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ushort> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ushort> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ushort>)
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::bhalf_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::bhalf_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::bhalf_t>)
{
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);