test_conv_bwd_data_scale

This commit is contained in:
Graner, Johannes
2026-01-07 10:08:37 -05:00
parent e00ef08809
commit 2f83bac119

View File

@@ -21,7 +21,7 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp"
using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
@@ -55,38 +55,24 @@ class TestGroupedConvndBwdData : public ::testing::Test
// NOTE(review): this span is a diff hunk rendered without its +/- markers, so
// two versions of RunReference appear interleaved and the text is not
// compilable as shown. Presumably the Tensor<WeiDataType>&/Tensor<OutDataType>&
// parameters together with the ReferenceConvBwdData statements are the removed
// side, and the DeviceMem& parameters together with the naive_conv_bwd_data
// statements are the added side -- confirm against the repository.
//
// RunReference produces the reference result that ends up in in_host.
//   conv_param : convolution description; its strides, dilations and pads are
//                read by the host path and the whole object is forwarded to the
//                GPU path.
//   in_host    : tensor handed to the host reference and, in the GPU path, the
//                destination of the device read-back.
void RunReference(ck::utils::conv::ConvParam& conv_param,
Tensor<InDataType>& in_host,
// Host-tensor form of the weight/output parameters (consumed by MakeArgument below).
Tensor<WeiDataType>& wei,
Tensor<OutDataType>& out)
// Device-buffer form of the weight/output parameters (consumed by naive_conv_bwd_data below).
DeviceMem& wei_device_buf,
DeviceMem& out_device_buf)
{
// Host path: instantiate the reference bwd-data operator. The three literal
// zeros are presumably the A/B/D elementwise tensor counts named in the inline
// comments -- TODO confirm the position of ComputeDataType against the template.
auto ref_conv =
ck::tensor_operation::host::ReferenceConvBwdData<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
0, /*Num A Elementwise Tensors*/
0, /*Num B Elementwise Tensors*/
0,
ComputeDataType> /*Num D Elementwise
Tensors*/
{};
// GPU reference
// Scratch device buffer sized from in_host's element space, cleared before use.
DeviceMem gpu_ref_in_dev(sizeof(InDataType) * in_host.mDesc.GetElementSpaceSize());
gpu_ref_in_dev.SetZero(); // bwd data needs zero initialization
// Host path: invoker object for ref_conv.
auto ref_invoker = ref_conv.MakeInvoker();
// GPU path: pass raw typed device pointers (scratch input, weights, output),
// the conv description and the element-wise ops; alpha (declared outside this
// span) is carried by the input element op.
ck::ref::naive_conv_bwd_data<InLayout, WeiLayout, OutLayout>(
static_cast<InDataType*>(gpu_ref_in_dev.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param,
InElementOp{alpha},
WeiElementOp{},
OutElementOp{});
// Host path: build the argument from the three host tensors, the filter
// strides/dilations and left/right pads taken from conv_param, and the same
// element-wise ops as the GPU path, then run it.
auto ref_argument = ref_conv.MakeArgument(in_host,
wei,
out,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
InElementOp{alpha},
WeiElementOp{},
OutElementOp{});
ref_invoker.Run(ref_argument);
// GPU path: synchronize the device (checking the returned status) before the
// scratch buffer is read back into in_host's storage.
ck::hip_check_error(hipDeviceSynchronize());
gpu_ref_in_dev.FromDevice(in_host.mData.data());
}
bool PerformConvDataScale(ck::utils::conv::ConvParam& conv_param, const ck::index_t split_k)
@@ -121,10 +107,11 @@ class TestGroupedConvndBwdData : public ::testing::Test
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in_device.mData.data());
out_device_buf.ToDevice(out.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
RunReference(conv_param, in_host, wei_device_buf, out_device_buf);
std::array<ck::index_t, NDimSpatial + 3> out_lengths{};
std::array<ck::index_t, NDimSpatial + 3> out_strides{};
std::array<ck::index_t, NDimSpatial + 3> wei_lengths{};
@@ -149,8 +136,6 @@ class TestGroupedConvndBwdData : public ::testing::Test
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
RunReference(conv_param, in_host, wei, out);
using DeviceOp =
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
OutLayout,