mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
test_grouped_convnd_fwd_bias_clamp
This commit is contained in:
@@ -21,6 +21,7 @@
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -156,8 +157,9 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
|
||||
bias_device_buf.ToDevice(bias.mData.data());
|
||||
|
||||
// run reference op
|
||||
if(do_verification)
|
||||
if(do_verification == 1)
|
||||
{
|
||||
// CPU reference
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
@@ -190,6 +192,75 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
else if(do_verification == 2)
|
||||
{
|
||||
// GPU reference
|
||||
std::vector<ck::index_t> d_lengths_vec(NDimSpatial + 3);
|
||||
std::vector<ck::index_t> d_strides_vec(NDimSpatial + 3);
|
||||
|
||||
d_lengths_vec[0] = conv_param.G_;
|
||||
d_lengths_vec[1] = conv_param.N_;
|
||||
d_lengths_vec[2] = conv_param.K_;
|
||||
for(ck::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
d_lengths_vec[3 + i] = static_cast<ck::index_t>(conv_param.output_spatial_lengths_[i]);
|
||||
}
|
||||
|
||||
if constexpr(BiasGK)
|
||||
{
|
||||
// GK bias layout: bias is addressed by (G, K) only, so N and the spatial dimensions get zero strides
|
||||
d_strides_vec[0] = K;
|
||||
d_strides_vec[1] = 0;
|
||||
d_strides_vec[2] = 1;
|
||||
for(ck::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
d_strides_vec[3 + i] = 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Full GNKHW layout - same as output
|
||||
ck::ranges::copy(out_g_n_k_wos_desc.GetStrides(), d_strides_vec.begin());
|
||||
}
|
||||
|
||||
std::array<const OutDataType*, 1> d_ptrs = {
|
||||
reinterpret_cast<const OutDataType*>(bias_device_buf.GetDeviceBuffer())};
|
||||
std::array<std::vector<ck::index_t>, 1> d_lengths = {d_lengths_vec};
|
||||
std::array<std::vector<ck::index_t>, 1> d_strides = {d_strides_vec};
|
||||
|
||||
std::array<const InDataType*, 1> in_ptrs = {
|
||||
reinterpret_cast<const InDataType*>(in_device_buf.GetDeviceBuffer())};
|
||||
std::array<const WeiDataType*, 1> wei_ptrs = {
|
||||
reinterpret_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer())};
|
||||
|
||||
ck::ref::naive_conv_fwd_multi_abd<0,
|
||||
0,
|
||||
1,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
OutDataType>( // Explicitly specify TD = OutDataType
|
||||
in_ptrs,
|
||||
wei_ptrs,
|
||||
d_ptrs,
|
||||
reinterpret_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
conv_param,
|
||||
d_lengths,
|
||||
d_strides,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
out_device_buf.FromDevice(host_output.mData.data());
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_avg_time = 0;
|
||||
|
||||
@@ -49,7 +49,7 @@ class TestGroupedConvndFwd : public ::testing::Test
|
||||
DataType,
|
||||
IndexType,
|
||||
false /*BiasGK*/>(
|
||||
true, // do_verification
|
||||
2, // do_verification: 2 selects the GPU reference path
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
|
||||
Reference in New Issue
Block a user