diff --git a/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp index 7f289c1938..4e79db91c4 100644 --- a/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp +++ b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp @@ -180,6 +180,10 @@ int main(int argc, char* argv[]) out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + // reset input to zero + in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1{0}); + in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data()); + // do GEMM auto conv = DeviceConvBwdDataInstance{}; auto invoker = conv.MakeInvoker(); diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index 185b96626b..27d7e0882a 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -459,6 +459,16 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda) { + // check slice is valid + const index_t Y = filter_spatial_lengths_[0]; + const index_t X = filter_spatial_lengths_[1]; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda); + if(YDotSlice * XDotSlice <= 0) + { + continue; + } + const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( N, K, diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp index 72cc021643..3d7e3d3b4b 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace device_conv2d_bwd_data_instance { -using BF16 = ushort; +using BF16 = ck::bhalf_t; using F32 = float; template diff --git a/profiler/include/profile_conv_bwd_data_impl.hpp b/profiler/include/profile_conv_bwd_data_impl.hpp index 019020c2ac..6f291c4327 100644 --- a/profiler/include/profile_conv_bwd_data_impl.hpp +++ b/profiler/include/profile_conv_bwd_data_impl.hpp @@ -11,7 +11,7 @@ using F16 = ck::half_t; using F32 = float; -using BF16 = ushort; +using BF16 = ck::bhalf_t; using INT8 = int8_t; namespace ck { namespace tensor_operation { @@ -172,9 +172,9 @@ void profile_conv_bwd_data_impl(int do_verification, ck::tensor_operation::device::device_conv2d_bwd_data_instance:: add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); } - else if constexpr(ck::is_same_v, ushort> && - ck::is_same_v, ushort> && - ck::is_same_v, ushort>) + else if constexpr(ck::is_same_v, ck::bhalf_t> && + ck::is_same_v, ck::bhalf_t> && + ck::is_same_v, ck::bhalf_t>) { ck::tensor_operation::device::device_conv2d_bwd_data_instance:: add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); diff --git a/test/conv2d_bwd_data/conv2d_bwd_data.cpp b/test/conv2d_bwd_data/conv2d_bwd_data.cpp index 0d26596396..e3caa52bef 100644 --- a/test/conv2d_bwd_data/conv2d_bwd_data.cpp +++ b/test/conv2d_bwd_data/conv2d_bwd_data.cpp @@ -182,8 +182,8 @@ int main(int argc, char* argv[]) out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - - in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1{5}); + // reset input to zero + in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1{0}); in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data()); // get host result @@ -225,9 +225,9 @@ int main(int argc, char* argv[]) ck::tensor_operation::device::device_conv2d_bwd_data_instance:: add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); } - else if constexpr(ck::is_same_v, ushort> && - ck::is_same_v, ushort> && - ck::is_same_v, ushort>) + else if constexpr(ck::is_same_v, ck::bhalf_t> && + ck::is_same_v, ck::bhalf_t> && + ck::is_same_v, ck::bhalf_t>) { ck::tensor_operation::device::device_conv2d_bwd_data_instance:: add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);