test_grouped_convnd_fwd_bilinear

This commit is contained in:
Graner, Johannes
2026-01-08 08:04:56 -05:00
parent 32988018d7
commit 2e36ef8916
2 changed files with 60 additions and 3 deletions

View File

@@ -22,6 +22,7 @@
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
namespace ck {
namespace profiler {
@@ -129,8 +130,9 @@ bool profile_grouped_conv_fwd_bilinear_impl(
wei_device_buf.ToDevice(weight.mData.data());
d_device_buf.ToDevice(d_tensor.mData.data());
if(do_verification)
if(do_verification == 1)
{
// CPU reference
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
NDimSpatial,
InDataType,
@@ -167,6 +169,61 @@ bool profile_grouped_conv_fwd_bilinear_impl(
host_output(idx) = ck::type_convert<OutDataType>(out_val);
});
}
else if(do_verification == 2)
{
// GPU reference
std::vector<ck::index_t> d_lengths_vec(NDimSpatial + 3);
std::vector<ck::index_t> d_strides_vec(NDimSpatial + 3);
d_lengths_vec[0] = conv_param.G_;
d_lengths_vec[1] = conv_param.N_;
d_lengths_vec[2] = conv_param.K_;
for(ck::index_t i = 0; i < NDimSpatial; ++i)
{
d_lengths_vec[3 + i] = static_cast<ck::index_t>(conv_param.output_spatial_lengths_[i]);
}
// D tensor has same layout as output
ck::ranges::copy(d_host_tensor_descriptor.GetStrides(), d_strides_vec.begin());
std::array<const DDataType*, 1> d_ptrs = {
reinterpret_cast<const DDataType*>(d_device_buf.GetDeviceBuffer())};
std::array<std::vector<ck::index_t>, 1> d_lengths = {d_lengths_vec};
std::array<std::vector<ck::index_t>, 1> d_strides = {d_strides_vec};
std::array<const InDataType*, 1> in_ptrs = {
reinterpret_cast<const InDataType*>(in_device_buf.GetDeviceBuffer())};
std::array<const WeiDataType*, 1> wei_ptrs = {
reinterpret_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer())};
ck::ref::naive_conv_fwd_multi_abd<0,
0,
1,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DDataType>( // Explicitly specify D tensor type
in_ptrs,
wei_ptrs,
d_ptrs,
reinterpret_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param,
d_lengths,
d_strides,
InElementOp{},
WeiElementOp{},
bilinear_op);
HIP_CHECK_ERROR(hipDeviceSynchronize());
out_device_buf.FromDevice(host_output.mData.data());
}
std::string best_op_name;
float best_avg_time = 0;

View File

@@ -66,10 +66,10 @@ class TestGroupedConvndFwdBilinear : public ::testing::Test
OutDataType,
AComputeType,
BComputeType,
IndexType>(true, // do_verification
IndexType>(2, // do_verification
1, // init_method: integer value
false, // do_log
true, // time_kernel
false, // time_kernel
param,
bilinear_op);
}