debugging conv

This commit is contained in:
Chao Liu
2022-05-25 19:01:02 +00:00
committed by Adam Osewski
parent bc8f54299a
commit b109516455

View File

@@ -8,6 +8,7 @@
#include "conv_util.hpp"
#include "device.hpp"
#include "device_tensor.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "host_tensor.hpp"
@@ -78,6 +79,52 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::
7, // CThreadTransferSrcDstVectorDim
1>; // CThreadTransferDstScalarPerVector
// Dedicated 2-D forward-convolution device instance built on the C-shuffle XDL
// device op. get_conv_instance() returns this type for num_dim_spatial == 2 while
// the generic DeviceConvNDFwdInstance<2> path is compiled out with #if 0.
//
// NOTE(review): every argument is positional. The trailing labels are the only
// description available here; the "presumably" notes are assumptions to confirm
// against the template parameter list declared in
// device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp.
#if 1
using DeviceConv2DFwdInstance = ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
ck::half_t, // presumably the input data type - TODO confirm
ck::half_t, // presumably the weight data type - TODO confirm
ck::half_t, // presumably the output data type - TODO confirm
float, // presumably the accumulation data type - TODO confirm
ck::tensor_operation::element_wise::PassThrough, // presumably the input element-wise op - TODO confirm
ck::tensor_operation::element_wise::PassThrough, // presumably the weight element-wise op - TODO confirm
ck::tensor_operation::element_wise::PassThrough, // presumably the output element-wise op - TODO confirm
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
256, // block_size
128, // m_per_block
64, // n_per_block
4, // k_per_block
8, // k1
32, // m_per_xdl
32, // n_per_xdl
2, // m_xdl_per_wave
1, // n_xdl_per_wave
// First of two identically labelled block-transfer groups.
// NOTE(review): presumably this group configures the first GEMM operand (A) - confirm.
S<4,64,1>, // thread_cluster_length
S<1,0,2>, // thread_cluster_arrange_order
S<1,0,2>, // src_access_order
2, // src_vector_dim
8, // src_scalar_per_vector
8, // dst_scalar_per_vector
true, // add_extra_dim
// Second block-transfer group; same values and labels as the group above.
// NOTE(review): presumably this group configures the second GEMM operand (B) - confirm.
S<4,64,1>, // thread_cluster_length
S<1,0,2>, // thread_cluster_arrange_order
S<1,0,2>, // src_access_order
2, // src_vector_dim
8, // src_scalar_per_vector
8, // dst_scalar_per_vector
true, // add_extra_dim
// NOTE(review): the next two labels repeat the m/n_xdl_per_wave labels used earlier
// in this list but carry different values (1, 1 vs 2, 1); presumably these are the
// per-shuffle counts for the C-shuffle output stage - confirm and relabel.
1, // m_xdl_per_wave
1, // n_xdl_per_wave
S<1,1,32,1,1,8>, // m_n_block_wave_per_xdl
8 // scalar_per_vector
>;
#endif
template <ck::index_t NumDimSpatial>
using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
WeiDataType,
@@ -95,7 +142,11 @@ DeviceConvFwdBasePtr get_conv_instance(int num_dim_spatial)
return std::make_unique<DeviceConvNDFwdInstance<3>>();
}
case 2: {
#if 0
return std::make_unique<DeviceConvNDFwdInstance<2>>();
#else
return std::make_unique<DeviceConv2DFwdInstance>();
#endif
}
case 1: {
return std::make_unique<DeviceConvNDFwdInstance<1>>();
@@ -291,7 +342,7 @@ int main(int argc, char* argv[])
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << conv->GetTypeString()
<< std::endl;
if(do_verification)