test_grouped_convnd_fwd_bias_clamp

This commit is contained in:
Graner, Johannes
2026-01-08 08:04:37 -05:00
parent 70049434d2
commit 32988018d7
2 changed files with 73 additions and 2 deletions

View File

@@ -21,6 +21,7 @@
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
namespace ck {
namespace profiler {
@@ -156,8 +157,9 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
bias_device_buf.ToDevice(bias.mData.data());
// run reference op
if(do_verification)
if(do_verification == 1)
{
// CPU reference
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
@@ -190,6 +192,75 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
ref_invoker.Run(ref_argument);
}
else if(do_verification == 2)
{
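// GPU reference (do_verification == 2): run the naive multi-ABD forward convolution on the
// device buffers, then copy the result back into host_output.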
std::vector<ck::index_t> d_lengths_vec(NDimSpatial + 3);
std::vector<ck::index_t> d_strides_vec(NDimSpatial + 3);
d_lengths_vec[0] = conv_param.G_;
d_lengths_vec[1] = conv_param.N_;
d_lengths_vec[2] = conv_param.K_;
for(ck::index_t i = 0; i < NDimSpatial; ++i)
{
d_lengths_vec[3 + i] = static_cast<ck::index_t>(conv_param.output_spatial_lengths_[i]);
}
if constexpr(BiasGK)
{
// GK bias layout: the G stride is K and the K stride is 1, while the N and spatial
// strides are zero so the same bias value is reused across those dimensions.
d_strides_vec[0] = K;
d_strides_vec[1] = 0;
d_strides_vec[2] = 1;
for(ck::index_t i = 0; i < NDimSpatial; ++i)
{
d_strides_vec[3 + i] = 0;
}
}
else
{
// Otherwise the bias is indexed like the output tensor: reuse the strides of the
// output descriptor (out_g_n_k_wos_desc).
ck::ranges::copy(out_g_n_k_wos_desc.GetStrides(), d_strides_vec.begin());
}
std::array<const OutDataType*, 1> d_ptrs = {
reinterpret_cast<const OutDataType*>(bias_device_buf.GetDeviceBuffer())};
std::array<std::vector<ck::index_t>, 1> d_lengths = {d_lengths_vec};
std::array<std::vector<ck::index_t>, 1> d_strides = {d_strides_vec};
std::array<const InDataType*, 1> in_ptrs = {
reinterpret_cast<const InDataType*>(in_device_buf.GetDeviceBuffer())};
std::array<const WeiDataType*, 1> wei_ptrs = {
reinterpret_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer())};
ck::ref::naive_conv_fwd_multi_abd<0,
0,
1,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
OutDataType>( // Explicitly specify TD = OutDataType
in_ptrs,
wei_ptrs,
d_ptrs,
reinterpret_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param,
d_lengths,
d_strides,
in_element_op,
wei_element_op,
out_element_op);
HIP_CHECK_ERROR(hipDeviceSynchronize());
out_device_buf.FromDevice(host_output.mData.data());
}
std::string best_op_name;
float best_avg_time = 0;

View File

@@ -49,7 +49,7 @@ class TestGroupedConvndFwd : public ::testing::Test
DataType,
IndexType,
false /*BiasGK*/>(
true, // do_verification
2, // do_verification
1, // init_method: integer value
false, // do_log
false, // time_kernel