diff --git a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp index 84d013bca7..f1f985883c 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp @@ -21,7 +21,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" using ::ck::DeviceMem; using ::ck::HostTensorDescriptor; @@ -55,38 +55,24 @@ class TestGroupedConvndBwdData : public ::testing::Test void RunReference(ck::utils::conv::ConvParam& conv_param, Tensor& in_host, - Tensor& wei, - Tensor& out) + DeviceMem& wei_device_buf, + DeviceMem& out_device_buf) { - auto ref_conv = - ck::tensor_operation::host::ReferenceConvBwdData /*Num D Elementwise - Tensors*/ - {}; + // GPU reference + DeviceMem gpu_ref_in_dev(sizeof(InDataType) * in_host.mDesc.GetElementSpaceSize()); + gpu_ref_in_dev.SetZero(); // bwd data needs zero initialization - auto ref_invoker = ref_conv.MakeInvoker(); + ck::ref::naive_conv_bwd_data( + static_cast(gpu_ref_in_dev.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + InElementOp{alpha}, + WeiElementOp{}, + OutElementOp{}); - auto ref_argument = ref_conv.MakeArgument(in_host, - wei, - out, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - InElementOp{alpha}, - WeiElementOp{}, - OutElementOp{}); - - ref_invoker.Run(ref_argument); + ck::hip_check_error(hipDeviceSynchronize()); + gpu_ref_in_dev.FromDevice(in_host.mData.data()); } bool PerformConvDataScale(ck::utils::conv::ConvParam& conv_param, const ck::index_t split_k) @@ -121,10 +107,11 @@ class TestGroupedConvndBwdData : public ::testing::Test DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); - in_device_buf.ToDevice(in_device.mData.data()); out_device_buf.ToDevice(out.mData.data()); wei_device_buf.ToDevice(wei.mData.data()); + RunReference(conv_param, in_host, wei_device_buf, out_device_buf); + std::array out_lengths{}; std::array out_strides{}; std::array wei_lengths{}; @@ -149,8 +136,6 @@ class TestGroupedConvndBwdData : public ::testing::Test copy(conv_param.input_left_pads_, input_left_pads); copy(conv_param.input_right_pads_, input_right_pads); - RunReference(conv_param, in_host, wei, out); - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD