mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Fix conv2d bwd data bug when filter is 1x1 and stride = 2 (#132)
* fix bwd data bug for 1x1 filter with stride 2 * change ushort to ck::bhalf_t * reset input to zero Co-authored-by: ltqin <letaoqin@amd.com>
This commit is contained in:
@@ -11,7 +11,7 @@
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using BF16 = ushort;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using INT8 = int8_t;
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -172,9 +172,9 @@ void profile_conv_bwd_data_impl(int do_verification,
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
|
||||
}
|
||||
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ushort> &&
|
||||
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ushort> &&
|
||||
ck::is_same_v<ck::remove_cv_t<OutDataType>, ushort>)
|
||||
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::bhalf_t> &&
|
||||
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::bhalf_t> &&
|
||||
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::bhalf_t>)
|
||||
{
|
||||
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
|
||||
|
||||
Reference in New Issue
Block a user