diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp index 50cd58eec3..2a282edbc8 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp @@ -21,6 +21,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" namespace ck { namespace profiler { @@ -156,8 +157,9 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, bias_device_buf.ToDevice(bias.mData.data()); // run reference op - if(do_verification) + if(do_verification == 1) { + // CPU reference auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd d_lengths_vec(NDimSpatial + 3); + std::vector d_strides_vec(NDimSpatial + 3); + + d_lengths_vec[0] = conv_param.G_; + d_lengths_vec[1] = conv_param.N_; + d_lengths_vec[2] = conv_param.K_; + for(ck::index_t i = 0; i < NDimSpatial; ++i) + { + d_lengths_vec[3 + i] = static_cast(conv_param.output_spatial_lengths_[i]); + } + + if constexpr(BiasGK) + { + // For GK bias layout: G*K, zero strides for N and spatial dimensions + d_strides_vec[0] = K; + d_strides_vec[1] = 0; + d_strides_vec[2] = 1; + for(ck::index_t i = 0; i < NDimSpatial; ++i) + { + d_strides_vec[3 + i] = 0; + } + } + else + { + // Full GNKHW layout - same as output + ck::ranges::copy(out_g_n_k_wos_desc.GetStrides(), d_strides_vec.begin()); + } + + std::array d_ptrs = { + reinterpret_cast(bias_device_buf.GetDeviceBuffer())}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; + + std::array in_ptrs = { + reinterpret_cast(in_device_buf.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(wei_device_buf.GetDeviceBuffer())}; + + ck::ref::naive_conv_fwd_multi_abd<0, + 0, + 1, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + OutDataType>( // Explicitly specify TD = OutDataType + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + out_device_buf.FromDevice(host_output.mData.data()); + } std::string best_op_name; float best_avg_time = 0; diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp index d1706d4cec..68a8b016e3 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp @@ -49,7 +49,7 @@ class TestGroupedConvndFwd : public ::testing::Test DataType, IndexType, false /*BiasGK*/>( - true, // do_verification + 2, // do_verification 1, // init_method: integer value false, // do_log false, // time_kernel