test_convnd_fwd

This commit is contained in:
Graner, Johannes
2026-01-07 09:31:30 -05:00
parent d7497d2694
commit c7da77d51b
2 changed files with 42 additions and 5 deletions

View File

@@ -21,6 +21,7 @@
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
namespace ck {
namespace profiler {
@@ -107,8 +108,11 @@ bool profile_conv_fwd_impl(int do_verification,
in_device_buf.ToDevice(input.mData.data());
wei_device_buf.ToDevice(weight.mData.data());
// profile device op instances
bool pass = true;
// run reference op
if(do_verification)
if(do_verification == 1)
{
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
@@ -135,6 +139,24 @@ bool profile_conv_fwd_impl(int do_verification,
ref_invoker.Run(ref_argument);
}
// GPU reference (compute once, compare in kernel loop)
// Host-side tensor that receives the GPU reference result. It is declared outside the
// `if` below, so it is constructed for every do_verification value, but it is only
// filled in when do_verification == 2.
Tensor<OutDataType> gpu_ref_output(out_g_n_k_wos_desc);
if(do_verification == 2)
{
// Scratch device buffer for the reference output, sized from device_output's descriptor.
// It is local to this block, so only the host copy in gpu_ref_output outlives it.
DeviceMem gpu_ref_out_dev(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize());
// Run the naive reference convolution directly on the existing input/weight device
// buffers, writing into the scratch buffer above.
ck::ref::naive_conv_fwd<InLayout, WeiLayout, OutLayout>(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(gpu_ref_out_dev.GetDeviceBuffer()),
conv_param,
in_element_op,
wei_element_op,
out_element_op);
// NOTE(review): presumably naive_conv_fwd returns before the device work has finished,
// hence the explicit synchronize before reading the result back -- confirm against
// naive_conv_fwd_gpu.hpp.
hip_check_error(hipDeviceSynchronize());
// Copy the reference result from the scratch device buffer into the host tensor.
gpu_ref_out_dev.FromDevice(gpu_ref_output.mData.data());
}
using DeviceOp = ck::tensor_operation::device::DeviceConvFwd<NDimSpatial,
InLayout,
@@ -158,8 +180,6 @@ bool profile_conv_fwd_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
// profile device op instances
bool pass = true;
for(auto& op_ptr : op_ptrs)
{
@@ -217,7 +237,7 @@ bool profile_conv_fwd_impl(int do_verification,
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
if(do_verification == 1)
{
out_device_buf.FromDevice(device_output.mData.data());
@@ -233,6 +253,23 @@ bool profile_conv_fwd_impl(int do_verification,
<< std::endl;
}
}
else if(do_verification == 2)
{
// do_verification == 2: compare this result against gpu_ref_output (the GPU reference)
// instead of taking the do_verification == 1 branch.
// Copy the device result into the host tensor before comparing.
out_device_buf.FromDevice(device_output.mData.data());
// `&` does not short-circuit, so check_err is still evaluated when `pass` is already
// false; once `pass` becomes false it stays false.
pass = pass & ck::utils::check_err(device_output, gpu_ref_output);
if(do_log)
{
// Write input, weight, reference output and device output to std::cout through
// LogRangeAsType<float>.
LogRangeAsType<float>(std::cout << "input : ", input.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "weight: ", weight.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "gpu_ref_output : ", gpu_ref_output.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "device_output: ", device_output.mData, ",")
<< std::endl;
}
}
}
else
{

View File

@@ -47,7 +47,7 @@ class TestConvndFwd : public ::testing::Test
ck::tensor_layout::convolution::NDHWK>>,
DataType,
DataType,
DataType>(true, // do_verification
DataType>(2, // do_verification: 2 = GPU reference
1, // init_method integer value
false, // do_log
false, // time_kernel