mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
debugging conv
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include "conv_util.hpp"
|
||||
#include "device.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
@@ -78,6 +79,52 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::
|
||||
7, // CThreadTransferSrcDstVectorDim
|
||||
1>; // CThreadTransferDstScalarPerVector
|
||||
|
||||
// Explicit 2D forward-convolution device instance: three ck::half_t type
// arguments, float as the fourth, PassThrough for all three element-wise
// operations, and the Default forward specialization.
// The `#if 1` guard is always true, so this alias is always compiled.
// NOTE(review): the per-argument labels are shorthand and several repeat; the
// group notes added below are assumptions based on argument order and the
// class name -- confirm against the template declaration (presumably in
// device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp).
#if 1
|
||||
using DeviceConv2DFwdInstance = ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
|
||||
ck::half_t, // presumably the input data type -- TODO confirm
|
||||
ck::half_t, // presumably the weight data type -- TODO confirm
|
||||
ck::half_t, // presumably the output data type -- TODO confirm
|
||||
float, // presumably the accumulator type -- TODO confirm
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
|
||||
|
||||
// ---- tile configuration ----
256, // block_size
|
||||
128, // m_per_block
|
||||
64, // n_per_block
|
||||
4, // k_per_block
|
||||
8, // k1
|
||||
32, // m_per_xdl
|
||||
32, // n_per_xdl
|
||||
2, // m_xdl_per_wave (first occurrence of this label; a second one with a different value follows near the end)
|
||||
1, // n_xdl_per_wave (first occurrence of this label)
|
||||
|
||||
|
||||
// ---- first block-transfer group (presumably the A / input operand -- TODO confirm) ----
S<4,64,1>, // thread_cluster_length (4*64*1 = 256, equal to block_size above)
|
||||
S<1,0,2>, // thread_cluster_arrange_order
|
||||
S<1,0,2>, // src_access_order
|
||||
2, // src_vector_dim
|
||||
8, // src_scalar_per_vector
|
||||
8, // dst_scalar_per_vector
|
||||
true, // add_extra_dim
|
||||
|
||||
// ---- second block-transfer group, same values as the first (presumably the B / weight operand -- TODO confirm) ----
S<4,64,1>, // thread_cluster_length (4*64*1 = 256, equal to block_size above)
|
||||
S<1,0,2>, // thread_cluster_arrange_order
|
||||
S<1,0,2>, // src_access_order
|
||||
2, // src_vector_dim
|
||||
8, // src_scalar_per_vector
|
||||
8, // dst_scalar_per_vector
|
||||
true, // add_extra_dim
|
||||
|
||||
// ---- trailing group (presumably the C-shuffle / output write-back settings -- TODO confirm) ----
1, // m_xdl_per_wave (second occurrence; presumably a per-shuffle count rather than the tile value above -- TODO confirm)
|
||||
1, // n_xdl_per_wave (second occurrence; same caveat)
|
||||
S<1,1,32,1,1,8>, // m_n_block_wave_per_xdl (1*1*32*1*1*8 = 256, equal to block_size above)
|
||||
8 // scalar_per_vector
|
||||
|
||||
>;
|
||||
#endif
|
||||
|
||||
template <ck::index_t NumDimSpatial>
|
||||
using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
|
||||
WeiDataType,
|
||||
@@ -95,7 +142,11 @@ DeviceConvFwdBasePtr get_conv_instance(int num_dim_spatial)
|
||||
return std::make_unique<DeviceConvNDFwdInstance<3>>();
|
||||
}
|
||||
case 2: {
|
||||
#if 0
|
||||
return std::make_unique<DeviceConvNDFwdInstance<2>>();
|
||||
#else
|
||||
return std::make_unique<DeviceConv2DFwdInstance>();
|
||||
#endif
|
||||
}
|
||||
case 1: {
|
||||
return std::make_unique<DeviceConvNDFwdInstance<1>>();
|
||||
@@ -291,7 +342,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << conv->GetTypeString()
|
||||
<< std::endl;
|
||||
|
||||
if(do_verification)
|
||||
|
||||
Reference in New Issue
Block a user