From e00ef088097ccc33abb7b18a554e22e9ab460081 Mon Sep 17 00:00:00 2001 From: "Graner, Johannes" Date: Wed, 7 Jan 2026 09:40:10 -0500 Subject: [PATCH] test_convnd_bwd_data --- .../profiler/profile_conv_bwd_data_impl.hpp | 56 +++++++++++++++++-- test/convnd_bwd_data/convnd_bwd_data_xdl.cpp | 2 +- 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp index a0f9b9ac25..bf5ffcb5d2 100644 --- a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp @@ -17,6 +17,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" namespace ck { namespace profiler { @@ -129,7 +130,10 @@ bool profile_conv_bwd_data_impl(int do_verification, out_device_buf.ToDevice(output.mData.data()); wei_device_buf.ToDevice(weight.mData.data()); - if(do_verification) + // profile device Conv instances + bool pass = true; + + if(do_verification == 1) { auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData gpu_ref_input(in_g_n_c_wis_desc); + if(do_verification == 2) + { + DeviceMem gpu_ref_in_dev(sizeof(InDataType) * + input_device_result.mDesc.GetElementSpaceSize()); + gpu_ref_in_dev.SetZero(); // bwd data needs zero initialization + + ck::ref::naive_conv_bwd_data( + static_cast(gpu_ref_in_dev.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + out_element_op); + + hip_check_error(hipDeviceSynchronize()); + gpu_ref_in_dev.FromDevice(gpu_ref_input.mData.data()); + } + using DeviceOp = ck::tensor_operation::device::DeviceConvBwdData>, DataType, DataType, - DataType>(true, // do_verification + DataType>(2, // do_verification: 2 = GPU reference 1, // init_method integer value false, // do_log false, // time_kernel