diff --git a/composable_kernel/include/tensor_operation/element_wise_operation.hpp b/composable_kernel/include/tensor_operation/element_wise_operation.hpp index 5f717b157d..1e45a5b7eb 100644 --- a/composable_kernel/include/tensor_operation/element_wise_operation.hpp +++ b/composable_kernel/include/tensor_operation/element_wise_operation.hpp @@ -1,6 +1,8 @@ #ifndef CK_ELEMENT_WISE_OPERATION_HPP #define CK_ELEMENT_WISE_OPERATION_HPP +#include "data_type.hpp" + namespace ck { namespace tensor_operation { namespace element_wise { diff --git a/device_operation/CMakeLists.txt b/device_operation/CMakeLists.txt index 440e16c2fa..5872b69b99 100644 --- a/device_operation/CMakeLists.txt +++ b/device_operation/CMakeLists.txt @@ -76,6 +76,11 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; ) +# device_conv1d_fwd_instance +set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp; +) + # device_conv2d_fwd_bias_relu_instance set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp; @@ -96,16 +101,18 @@ add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_S add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) add_library(device_batched_gemm_instance SHARED ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE}) -add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) -add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) -add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) 
-add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) +add_library(device_conv1d_fwd_instance SHARED ${DEVICE_CONV1D_FWD_INSTANCE_SOURCE}) +add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) +add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) +add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) +add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) target_include_directories(device_gemm_instance SYSTEM PUBLIC $) target_include_directories(device_gemm_bias_2d_instance SYSTEM PUBLIC $) target_include_directories(device_gemm_bias_relu_instance SYSTEM PUBLIC $) target_include_directories(device_gemm_bias_relu_add_instance SYSTEM PUBLIC $) target_include_directories(device_batched_gemm_instance SYSTEM PUBLIC $) +target_include_directories(device_conv1d_fwd_instance SYSTEM PUBLIC $) target_include_directories(device_conv2d_fwd_instance SYSTEM PUBLIC $) target_include_directories(device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $) target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $) @@ -116,6 +123,7 @@ target_compile_features(device_gemm_bias_2d_instance PUBLIC) target_compile_features(device_gemm_bias_relu_instance PUBLIC) target_compile_features(device_gemm_bias_relu_add_instance PUBLIC) target_compile_features(device_batched_gemm_instance PUBLIC) +target_compile_features(device_conv1d_fwd_instance PUBLIC) target_compile_features(device_conv2d_fwd_instance PUBLIC) target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC) target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC) @@ -126,6 +134,7 @@ set_target_properties(device_gemm_bias_2d_instance PROPERTIES POSITION_INDEPENDE 
set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_batched_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(device_conv1d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) @@ -136,7 +145,8 @@ install(TARGETS device_gemm_bias_2d_instance LIBRARY DESTINATION lib) install(TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib) install(TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib) install(TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib) diff --git a/device_operation/include/conv_utils.hpp b/device_operation/include/conv_utils.hpp new file mode 100644 index 0000000000..9aa616633e --- /dev/null +++ b/device_operation/include/conv_utils.hpp @@ -0,0 +1,198 @@ +#ifndef CONV_UTILS_HPP +#define CONV_UTILS_HPP + +#include +#include +#include 
+#include +#include +#include +#include + +#include "config.hpp" +#include "host_tensor.hpp" +#include "tensor_layout.hpp" + +namespace ck { +namespace conv_util { + +/** + * @brief Calculate number of FLOPs for Convolution + * + * @param[in] N Batch size. + * @param[in] C Number of input channels. + * @param[in] K Number of output channels. + * @param[in] filter_spatial_lengths Filter spatial dimensions lengths. + * @param[in] output_spatial_lengths Convolution output spatial dimensions + * lengths. + * + * @return The number of flops. + */ +std::size_t GetFlops(ck::index_t N, + ck::index_t C, + ck::index_t K, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths) +{ + // 2 * N * K * * C * + return static_cast(2) * N * K * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + static_cast(1), + std::multiplies()) * + C * + std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + static_cast(1), + std::multiplies()); +} + +/** + * @brief Calculate number of bytes read/write by convolution algorithm. + * + * @param[in] N Batch size. + * @param[in] C Number of input channels. + * @param[in] K Number of output channels. + * @param[in] input_spatial_lengths Input spatial dimensions lengths. + * @param[in] filter_spatial_lengths Filter spatial dimensions lengths. + * @param[in] output_spatial_lengths Output spatial dimensions lengths + * + * @tparam InDataType Input tensor data type. + * @tparam WeiDataType Weights tensor data type. + * @tparam OutDataType Output tensor data type. + * + * @return The number of used bytes. 
+ */ +template +std::size_t GetBtype(ck::index_t N, + ck::index_t C, + ck::index_t K, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths) +{ + // sizeof(InDataType) * (N * C * ) + + // sizeof(WeiDataType) * (K * C * ) + + // sizeof(OutDataType) * (N * K * ); + return sizeof(InDataType) * (N * C * + std::accumulate(std::begin(input_spatial_lengths), + std::end(input_spatial_lengths), + static_cast(1), + std::multiplies())) + + sizeof(WeiDataType) * (K * C * + std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + static_cast(1), + std::multiplies())) + + sizeof(OutDataType) * (N * K * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + static_cast(1), + std::multiplies())); +} + +struct ConvParams +{ + ConvParams() + : num_dim_spatial(2), + N(128), + K(256), + C(192), + filter_spatial_lengths(2, 3), + input_spatial_lengths(2, 71), + conv_filter_strides(2, 2), + conv_filter_dilations(2, 1), + input_left_pads(2, 1), + input_right_pads(2, 1) + { + } + + ck::index_t num_dim_spatial; + ck::index_t N; + ck::index_t K; + ck::index_t C; + + std::vector filter_spatial_lengths; + std::vector input_spatial_lengths; + + std::vector conv_filter_strides; + std::vector conv_filter_dilations; + + std::vector input_left_pads; + std::vector input_right_pads; + + std::vector GetOutputSpatialLengths() const + { + std::vector out_spatial_len(num_dim_spatial, 0); + for(ck::index_t i = 0; i < num_dim_spatial; ++i) + { + // XEff = (X - 1) * conv_dilation_w + 1; + // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck::index_t idx_eff = + (filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1; + out_spatial_len[i] = + (input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) / + conv_filter_strides[i] + + 1; + } + return out_spatial_len; + } +}; + +/** + * @brief Gets the host 
tensor descriptor. + * + * @param[in] dims The tensor dimensions lengths. Always in NCHW format. + * @param[in] layout The tensor data layout. + * + * @tparam TensorLayout Layout type. + * + * @return The host tensor descriptor object. + */ +template +HostTensorDescriptor GetHostTensorDescriptor(const std::vector& dims, + const TensorLayout& layout) +{ + std::size_t C = dims[1]; + // 1D + if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + + return HostTensorDescriptor(dims, std::vector({C * dims[2], dims[2], 1})); + } + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + return HostTensorDescriptor(dims, std::vector({C * dims[2], 1, C})); + } + // 2D + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + + return HostTensorDescriptor( + dims, std::vector{C * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1}); + } + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + return HostTensorDescriptor( + dims, std::vector{C * dims[2] * dims[3], 1, dims[3] * C, C}); + } + + std::stringstream err_msg; + err_msg << "Unsupported data layout provided: " << layout << "!"; + throw std::runtime_error(err_msg.str()); +} + +} // namespace conv_util +} // namespace ck + +#endif diff --git a/device_operation/include/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/device_operation/include/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index 63a832e150..ffa1815ab7 100644 --- a/device_operation/include/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/device_operation/include/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -215,7 +215,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Pad0) { - static_assert(ConvForwardSpecialization == -1, "Not implemented!"); } else diff --git 
a/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp b/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..2997652c82 --- /dev/null +++ b/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,865 @@ +#ifndef DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP + +#include +#include +#include +#include +#include + +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_fwd.hpp" +#include "convolution_forward_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Inputs with up to 3 spatial dimentions +// @li Input tensor in NHWC data format +// @li Weight tensor in KYXC data format +// @li Output tensor in NHWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template +struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwd +{ + using DeviceOp = DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = NumDimSpatial; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto 
GetWeightTensorDescriptor(ck::index_t gemm_n, ck::index_t gemm_k) + { + const ck::index_t gemm_k0 = gemm_k / GemmK1Number; + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k)); + + // wei_gemmk0_gemmn_gemmk1_grid_desc + return transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_pass_through_transform(gemm_n)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + static auto + GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n, ck::index_t gemm_m_pad) + { + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n)); + + // out_gemmm_gemmn_grid_desc + return transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(gemm_n)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template ::type = false> + static auto GetInputTensorDescriptor(ck::index_t N, + ck::index_t C, + ck::index_t gemm_m, + ck::index_t gemm_k, + ck::index_t gemm_m_pad, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) + { + const ck::index_t gemm_k0 = gemm_k / GemmK1Number; + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return 
transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_right_pad_transform(gemm_m, gemm_m_pad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { + const auto in_n_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + const auto in_n_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<2>{}, Sequence<0, 1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else + { + const index_t X = filter_spatial_lengths[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + const auto in_n_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + 
make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X, C)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_pass_through_transform(gemm_m)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + } + + template ::type = false> + static auto GetInputTensorDescriptor(ck::index_t N, + ck::index_t C, + ck::index_t gemm_m, + ck::index_t gemm_k, + ck::index_t gemm_m_pad, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& 
input_left_pads, + const std::vector& input_right_pads) + { + const ck::index_t gemm_k0 = gemm_k / GemmK1Number; + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_right_pad_transform(gemm_m, gemm_m_pad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return 
transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else + { + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, 
Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_pass_through_transform(gemm_m)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + } + + static index_t GetGemmMRaw(ck::index_t N, + const std::vector& output_spatial_lengths) + { + return N * std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + 1, + std::multiplies()); + } + + static index_t GetGemmK(ck::index_t C, const std::vector& filter_spatial_lengths) + { + return C * std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + 1, + std::multiplies()); + } + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t GemmMRaw = GetGemmMRaw(N, output_spatial_lengths); + const index_t GemmN = K; + const index_t GemmK = GetGemmK(C, filter_spatial_lengths); + + const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw; + + assert(GemmK % GemmK1Number == 0); + + // C = A^T*B + // A: + const auto 
in_gemmk0_gemmm_gemmk1_grid_desc = + GetInputTensorDescriptor(N, + C, + GemmMRaw, + GemmK, + GemmMPad, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + // B: + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = GetWeightTensorDescriptor(GemmN, GemmK); + // C: + const auto out_gemmm_gemmn_grid_desc = GetOutputTensorDescriptor(GemmMRaw, GemmN, GemmMPad); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum_t::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + 
Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = 
descs[I2]; + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, int nrepeat = 1) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + 
arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + 
static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(ck::index_t i = 0; i < NumDimSpatial; ++i) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(ck::index_t i = 0; i < NumDimSpatial; ++i) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.input_left_pads_[i] == 0 && + arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) 
+ { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv" << std::to_string(NumDimSpatial) + << "DFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/tensor_layout.hpp b/device_operation/include/tensor_layout.hpp index 4fae86d875..4904f004a0 100644 --- 
a/device_operation/include/tensor_layout.hpp +++ b/device_operation/include/tensor_layout.hpp @@ -12,37 +12,77 @@ namespace gemm { struct RowMajor : public BaseTensorLayout { + static constexpr const char* name = "RowMajor"; }; struct ColumnMajor : public BaseTensorLayout { + static constexpr const char* name = "ColumnMajor"; }; } // namespace gemm namespace convolution { +// 1D Conv +struct NWC : public BaseTensorLayout +{ + static constexpr const char* name = "NWC"; +}; + +struct KXC : public BaseTensorLayout +{ + static constexpr const char* name = "KXC"; +}; + +struct NWK : public BaseTensorLayout +{ + static constexpr const char* name = "NWK"; +}; + +struct NCW : public BaseTensorLayout +{ + static constexpr const char* name = "NCW"; +}; + +struct KCX : public BaseTensorLayout +{ + static constexpr const char* name = "KCX"; +}; + +struct NKW : public BaseTensorLayout +{ + static constexpr const char* name = "NKW"; +}; + +// 2D Conv struct NHWC : public BaseTensorLayout { + static constexpr const char* name = "NHWC"; }; struct KYXC : public BaseTensorLayout { + static constexpr const char* name = "KYXC"; }; struct NHWK : public BaseTensorLayout { + static constexpr const char* name = "NHWK"; }; struct NCHW : public BaseTensorLayout { + static constexpr const char* name = "NCHW"; }; struct KCYX : public BaseTensorLayout { + static constexpr const char* name = "KCYX"; }; struct NKHW : public BaseTensorLayout { + static constexpr const char* name = "NKHW"; }; struct NDHWC : public BaseTensorLayout @@ -59,6 +99,15 @@ struct NDHWK : public BaseTensorLayout } // namespace convolution +template < + typename Layout, + typename std::enable_if::value, bool>::type = false> +std::ostream& operator<<(std::ostream& os, const Layout&) +{ + os << Layout::name; + return os; +} + } // namespace tensor_layout } // namespace ck #endif diff --git a/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp 
b/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp new file mode 100644 index 0000000000..8702d18596 --- /dev/null +++ b/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp @@ -0,0 +1,112 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv1d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; + +//------------------------------------------------------------------------------ +// Conv1D +//------------------------------------------------------------------------------ + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| 
Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, 
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 
ConvFwdDefault, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 
1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, 
PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + 
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f32_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace device_conv1d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp index 402d65a6e0..69ff391968 100644 --- a/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp +++ b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -1,6 +1,6 @@ #include #include "config.hpp" -#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include 
"element_wise_operation.hpp" #include "device_operation_instance.hpp" @@ -28,67 +28,67 @@ static constexpr auto ConvFwd1x1S1P0 = // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 
256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, 
PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, 
ConvFwdDefault, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> // clang-format on >; using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 4, 32, 
32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 
128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| 
CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> // clang-format on >; using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + 
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> // clang-format on >; diff --git a/example/11_convnd_fwd_xdl/README.md b/example/11_convnd_fwd_xdl/README.md new file mode 100644 index 0000000000..d85a409165 --- /dev/null +++ b/example/11_convnd_fwd_xdl/README.md @@ -0,0 +1,65 @@ +# Instructions for ```convnd_fwd_xdl``` Example + +## Docker script +```bash +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ +rocm/tensorflow:rocm4.3.1-tf2.6-dev \ +/bin/bash +``` + +## Build ```convnd_fwd_xdl``` +```bash +mkdir build && cd build +``` + +```bash +# Need to specify target ID, example below is gfx908 +cmake \ +-D BUILD_DEV=OFF \ +-D CMAKE_BUILD_TYPE=Release \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +.. 
+``` + +```bash + make -j convnd_fwd_xdl +``` + +## Run ```convnd_fwd_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4: N spatial dimensions (default 2) +#Following arguments (depending on number of spatial dims): +# N, K, C, +# , (ie Y, X for 2D) +# , (ie Hi, Wi for 2D) +# , (ie Sy, Sx for 2D) +# , (ie Dy, Dx for 2D) +# , (ie LeftPy, LeftPx for 2D) +# , (ie RightPy, RightPx for 2D) +./example/convnd_fwd_xdl 0 1 100 +``` + +Result (MI100 @ 1087Mhz, 33.4TFlops peak FP32) +``` +input: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +weights: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} +output: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +arg.a_grid_desc_k0_m_k1_{432, 165888, 4} +arg.b_grid_desc_k0_n_k1_{432, 256, 4} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 100 times... 
+Perf: 4.43736 ms, 33.0753 TFlops, 150.357 GB/s +``` diff --git a/example/11_convnd_fwd_xdl/convnd_fwd_xdl.cpp b/example/11_convnd_fwd_xdl/convnd_fwd_xdl.cpp new file mode 100644 index 0000000000..614303a188 --- /dev/null +++ b/example/11_convnd_fwd_xdl/convnd_fwd_xdl.cpp @@ -0,0 +1,379 @@ +#include +#include +#include +#include + +#include "config.hpp" +#include "conv_utils.hpp" +#include "device.hpp" +#include "device_tensor.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "reference_conv_fwd.hpp" +#include "tensor_layout.hpp" + +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +using DeviceConvFwdBasePtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device:: + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + // clang-format off + InDataType, // + WeiDataType, // + OutDataType, // + AccDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 4, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // 
ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockTransferAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector +// clang-format on + +template +using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +DeviceConvFwdBasePtr GetConvInstance(int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +void PrintUseMsg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: run kernel # of times (>1)\n" + << "arg4: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, int argc, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 5; + if(cmdline_nargs != argc) + { + PrintUseMsg(); + exit(0); + } + + ck::conv_util::ConvParams params; + int arg_idx = 5; + + params.num_dim_spatial = num_dim_spatial; + params.N = 
std::stoi(argv[arg_idx++]); + params.K = std::stoi(argv[arg_idx++]); + params.C = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 2: { + return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWK{}); + } + case 1: { + return ck::conv_util::GetHostTensorDescriptor(dims, tl::NWK{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 2: { + return ck::conv_util::GetHostTensorDescriptor(dims, tl::KYXC{}); + } + case 1: { + return ck::conv_util::GetHostTensorDescriptor(dims, tl::KXC{}); + } + default: { + throw 
std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 2: { + return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{}); + } + case 1: { + return ck::conv_util::GetHostTensorDescriptor(dims, tl::NWC{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +int main(int argc, char* argv[]) +{ + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + int num_dim_spatial = 2; + + ck::conv_util::ConvParams params; + + if(argc >= 5) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + num_dim_spatial = std::stoi(argv[4]); + } + + if(argc >= 6) + { + params = ParseConvParams(num_dim_spatial, argc, argv); + } + + std::vector input_dims{static_cast(params.N), + static_cast(params.C)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths), + std::end(params.input_spatial_lengths)); + + std::vector filter_dims{static_cast(params.K), + static_cast(params.C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths), + std::end(params.filter_spatial_lengths)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N), + static_cast(params.K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); + Tensor weights(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial)); + Tensor host_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + Tensor device_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + + 
std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weights.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + + // do GEMM + auto conv = GetConvInstance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N, + params.K, + params.C, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv->IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "wrong! 
device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float ave_time = invoker->Run(argument.get(), nrepeat); + + std::size_t flop = ck::conv_util::GetFlops( + params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); + std::size_t num_btype = + ck::conv_util::GetBtype(params.N, + params.C, + params.K, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto verify_f = [&input, &weights, &host_output, ¶ms, &out_device_buf, &device_output]( + const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(device_output.mData.data()); + check_error(host_output, device_output); + }; + + switch(num_dim_spatial) + { + case 2: { + auto ref_conv = ReferenceConvNDFwdInstance<2>(); + verify_f(ref_conv); + break; + } + case 1: { + auto ref_conv = ReferenceConvNDFwdInstance<1>(); + verify_f(ref_conv); + break; + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 8377cf7679..7e6daa7ad6 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -23,6 +23,7 @@ set(CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE 7_conv2d_fwd_xdl_bias_relu_atomic set(GEMM_XDL_ALPHA_BETA_SOURCE 8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp) set(CONV2D_FWD_XDL_INT8_SOURCE 
9_conv2d_fwd_xdl_int8/conv2d_fwd_xdl_int8.cpp) set(CONV3D_FWD_XDL_SOURCE 10_conv3d_fwd_xdl/conv3d_fwd_xdl.cpp) +set(CONVND_FWD_XDL_SOURCE 11_convnd_fwd_xdl/convnd_fwd_xdl.cpp) add_executable(gemm_xdl ${GEMM_XDL_SOURCE}) add_executable(gemm_xdl_bias_relu ${GEMM_XDL_BIAS_RELU_SOURCE}) @@ -34,6 +35,7 @@ add_executable(conv2d_fwd_xdl_bias_relu_atomic_add ${CONV2D_FWD_XDL_BIAS_RELU_AT add_executable(gemm_xdl_alpha_beta ${GEMM_XDL_ALPHA_BETA_SOURCE}) add_executable(conv2d_fwd_xdl_int8 ${CONV2D_FWD_XDL_INT8_SOURCE}) add_executable(conv3d_fwd_xdl ${CONV3D_FWD_XDL_SOURCE}) +add_executable(convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE}) target_link_libraries(gemm_xdl PRIVATE host_tensor) target_link_libraries(gemm_xdl_bias_relu PRIVATE host_tensor) @@ -45,4 +47,4 @@ target_link_libraries(conv2d_fwd_xdl_bias_relu_atomic_add PRIVATE host_tensor) target_link_libraries(gemm_xdl_alpha_beta PRIVATE host_tensor) target_link_libraries(conv2d_fwd_xdl_int8 PRIVATE host_tensor) target_link_libraries(conv3d_fwd_xdl PRIVATE host_tensor) - +target_link_libraries(convnd_fwd_xdl PRIVATE host_tensor) diff --git a/reference_operation/include/reference_conv_fwd.hpp b/reference_operation/include/reference_conv_fwd.hpp index 6bcd7d28e0..0bba22423f 100644 --- a/reference_operation/include/reference_conv_fwd.hpp +++ b/reference_operation/include/reference_conv_fwd.hpp @@ -2,6 +2,7 @@ #define REFERENCE_CONV_FWD_HPP #include +#include #include #include "device_base.hpp" #include "host_tensor.hpp" @@ -10,21 +11,38 @@ namespace ck { namespace tensor_operation { namespace host { -// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X] +// +// @brief Reference implementation for forward convolution. +// +// @paragraph Supported tensor layouts. Input tensor supports NCHiWi data layout. +// Weights tensor supports KCYX data layout. Output tensor supports +// NKHoWo data layout. +// +// @tparam InDataType Input tensor data type. +// @tparam WeiDataType Weights tensor data type. 
+// @tparam OutDataType Output tensor data type. +// @tparam InElementwiseOperation Functor for input tensor elementwise +// operation. +// @tparam WeiElementwiseOperation Functor for weights tensor elementwise +// operation. +// @tparam NumDimSpatial Number of spatial dimensions. +// template + typename OutElementwiseOperation, + ck::index_t NumDimSpatial = 2, + typename std::enable_if= 1 && NumDimSpatial <= 3, bool>::type = false> struct ReferenceConvFwd : public device::BaseOperator { // Argument struct Argument : public device::BaseArgument { - Argument(const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, + Argument(const Tensor& input, + const Tensor& weight, + Tensor& output, std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, @@ -32,9 +50,9 @@ struct ReferenceConvFwd : public device::BaseOperator InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) - : in_n_c_hi_wi_{in_n_c_hi_wi}, - wei_k_c_y_x_{wei_k_c_y_x}, - out_n_k_ho_wo_{out_n_k_ho_wo}, + : input_{input}, + weight_{weight}, + output_{output}, conv_strides_{conv_filter_strides}, conv_dilations_{conv_filter_dilations}, in_left_pads_{input_left_pads}, @@ -45,9 +63,9 @@ struct ReferenceConvFwd : public device::BaseOperator { } - const Tensor& in_n_c_hi_wi_; - const Tensor& wei_k_c_y_x_; - Tensor& out_n_k_ho_wo_; + const Tensor& input_; + const Tensor& weight_; + Tensor& output_; std::vector conv_strides_; std::vector conv_dilations_; @@ -59,58 +77,98 @@ struct ReferenceConvFwd : public device::BaseOperator OutElementwiseOperation out_element_op_; }; - // Invoker struct Invoker : public device::BaseInvoker { using Argument = ReferenceConvFwd::Argument; float Run(const Argument& arg) { - auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { - float v_acc = 0; + if constexpr(NumDimSpatial == 1) + { + auto f_ncw = [&](auto n, auto k, auto wo) { + float v_acc = 0; - for(int c 
= 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) - { - for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) + for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) { - int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - - arg.in_left_pads_[0]; - for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) + for(int x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x) { - int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - - arg.in_left_pads_[1]; - if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 && - wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) + int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] - + arg.in_left_pads_[0]; + if(wi >= 0 && wi < arg.input_.mDesc.GetLengths()[2]) { float v_in; float v_wei; - arg.in_element_op_( - v_in, ck::type_convert(arg.in_n_c_hi_wi_(n, c, hi, wi))); - arg.wei_element_op_( - v_wei, ck::type_convert(arg.wei_k_c_y_x_(k, c, y, x))); + arg.in_element_op_(v_in, + static_cast(arg.input_(n, c, wi))); + arg.wei_element_op_(v_wei, + static_cast(arg.weight_(k, c, x))); v_acc += v_in * v_wei; } } } - } - float v_out; + float v_out; - arg.out_element_op_(v_out, v_acc); + arg.out_element_op_(v_out, v_acc); + arg.output_(n, k, wo) = v_out; + }; - arg.out_n_k_ho_wo_(n, k, ho, wo) = ck::type_convert(v_out); - }; + make_ParallelTensorFunctor(f_ncw, + arg.output_.mDesc.GetLengths()[0], + arg.output_.mDesc.GetLengths()[1], + arg.output_.mDesc.GetLengths()[2])( + std::thread::hardware_concurrency()); - make_ParallelTensorFunctor(f_nchw, - arg.out_n_k_ho_wo_.mDesc.GetLengths()[0], - arg.out_n_k_ho_wo_.mDesc.GetLengths()[1], - arg.out_n_k_ho_wo_.mDesc.GetLengths()[2], - arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])( - std::thread::hardware_concurrency()); + return 0; + } + else if constexpr(NumDimSpatial == 2) + { + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v_acc = 0; - return 0; + for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) + { + for(int y 
= 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - + arg.in_left_pads_[0]; + for(int x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - + arg.in_left_pads_[1]; + if(hi >= 0 && hi < arg.input_.mDesc.GetLengths()[2] && wi >= 0 && + wi < arg.input_.mDesc.GetLengths()[3]) + { + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, ck::type_convert(arg.input_(n, c, hi, wi))); + arg.wei_element_op_( + v_wei, ck::type_convert(arg.weight_(k, c, y, x))); + v_acc += v_in * v_wei; + } + } + } + } + + float v_out; + + arg.out_element_op_(v_out, v_acc); + arg.output_(n, k, ho, wo) = ck::type_convert(v_out); + }; + + make_ParallelTensorFunctor(f_nchw, + arg.output_.mDesc.GetLengths()[0], + arg.output_.mDesc.GetLengths()[1], + arg.output_.mDesc.GetLengths()[2], + arg.output_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } } float Run(const device::BaseArgument* p_arg, int) override @@ -127,9 +185,9 @@ struct ReferenceConvFwd : public device::BaseOperator bool IsSupportedArgument(const device::BaseArgument*) override { return true; } - static auto MakeArgument(const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, + static auto MakeArgument(const Tensor& input, + const Tensor& weight, + Tensor& output, std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, @@ -138,9 +196,9 @@ struct ReferenceConvFwd : public device::BaseOperator WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) { - return Argument{in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo, + return Argument{input, + weight, + output, conv_filter_strides, conv_filter_dilations, input_left_pads, diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 8dbd550227..ff483b8117 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -10,6 +10,7 @@ 
// Element-wise equality check for two vectors; prints a diagnostic followed
// by `msg` and returns false on the first size or value mismatch.
template <typename T>
bool cmp_vec(const std::vector<T>& out, const std::vector<T>& ref, const std::string& msg)
{
    if(out.size() != ref.size())
    {
        std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl
                  << msg << std::endl;
        return false;
    }

    for(std::size_t idx = 0; idx < ref.size(); ++idx)
    {
        const bool same = (out[idx] == ref[idx]);
        if(!same)
        {
            std::cout << "out[" << idx << "] != ref[" << idx << "]: " << out[idx]
                      << "!=" << ref[idx] << std::endl
                      << msg << std::endl;
            return false;
        }
    }

    return true;
}
------------------------------------ + // input NCHW {128,192,71,71}, + // weights KCYX {256,192,3,3}, + // stride {2,2}, + // dilations {1,1}, + // padding {{1,1}, {1,1}} + ck::conv_util::ConvParams conv_params; + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + res = cmp_vec(out_spatial_len, + std::vector{36, 36}, + "Error: ConvParams 2D default constructor."); + + conv_params.conv_filter_strides = std::vector{1, 1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + res = cmp_vec( + out_spatial_len, std::vector{71, 71}, "Error: ConvParams 2D stride {1,1}."); + + conv_params.conv_filter_strides = std::vector{2, 2}; + conv_params.input_left_pads = std::vector{2, 2}; + conv_params.input_right_pads = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + res = cmp_vec(out_spatial_len, + std::vector{37, 37}, + "Error: ConvParams 2D padding left/right {2,2}."); + + conv_params.conv_filter_dilations = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + res = cmp_vec( + out_spatial_len, std::vector{36, 36}, "Error: ConvParams 2D dilation {2,2}."); + + conv_params.conv_filter_strides = std::vector{3, 3}; + conv_params.input_left_pads = std::vector{1, 1}; + conv_params.input_right_pads = std::vector{1, 1}; + conv_params.conv_filter_dilations = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + res = cmp_vec(out_spatial_len, + std::vector{23, 23}, + "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."); + + // -------------------------- 1D ------------------------------------ + conv_params.num_dim_spatial = 1; + conv_params.filter_spatial_lengths = std::vector{3}; + conv_params.input_spatial_lengths = std::vector{71}; + conv_params.conv_filter_strides = std::vector{2}; + conv_params.conv_filter_dilations = std::vector{1}; + conv_params.input_left_pads = std::vector{1}; + conv_params.input_right_pads = std::vector{1}; + + out_spatial_len = 
conv_params.GetOutputSpatialLengths(); + res = cmp_vec( + out_spatial_len, std::vector{36}, "Error: ConvParams 1D default constructor."); + + conv_params.conv_filter_strides = std::vector{1, 1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + res = + cmp_vec(out_spatial_len, std::vector{71}, "Error: ConvParams 1D stride {1}."); + + conv_params.conv_filter_strides = std::vector{2}; + conv_params.input_left_pads = std::vector{2}; + conv_params.input_right_pads = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + res = cmp_vec(out_spatial_len, + std::vector{37}, + "Error: ConvParams 1D padding left/right {2}."); + + conv_params.conv_filter_dilations = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + res = cmp_vec( + out_spatial_len, std::vector{36}, "Error: ConvParams 1D dilation {2}."); + + conv_params.conv_filter_strides = std::vector{3}; + conv_params.input_left_pads = std::vector{1}; + conv_params.input_right_pads = std::vector{1}; + conv_params.conv_filter_dilations = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + res = cmp_vec(out_spatial_len, + std::vector{23}, + "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."); + + return res; +} + +bool TestGetHostTensorDescriptor() +{ + bool res{true}; + namespace tl = ck::tensor_layout::convolution; + std::vector dims{2, 3, 4, 5}; + HostTensorDescriptor h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{}); + res = cmp_vec(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions lengths!"); + res = + cmp_vec(h.GetStrides(), {3 * 4 * 5, 1, 3 * 5, 3}, "Error: wrong NHWC dimensions strides!"); + + h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCHW{}); + res = cmp_vec(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!"); + res = + cmp_vec(h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!"); + + dims = std::vector{2, 3, 4}; + h = 
ck::conv_util::GetHostTensorDescriptor(dims, tl::NWC{}); + res = cmp_vec(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!"); + res = cmp_vec(h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!"); + + h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCW{}); + res = cmp_vec(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!"); + res = cmp_vec(h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!"); + + return res; +} + +} // namespace + +int main(void) +{ + bool res = TestConvParams_GetOutputSpatialLengths(); + std::cout << "TestConvParams_GetOutputSpatialLengths ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; + res = TestGetHostTensorDescriptor(); + std::cout << "TestGetHostTensorDescriptor ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return 0; +} diff --git a/test/convnd_fwd_xdl/main.cpp b/test/convnd_fwd_xdl/main.cpp new file mode 100644 index 0000000000..045becf32f --- /dev/null +++ b/test/convnd_fwd_xdl/main.cpp @@ -0,0 +1,262 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "config.hpp" +#include "conv_utils.hpp" +#include "device.hpp" +#include "device_tensor.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "reference_conv_fwd.hpp" +#include "tensor_layout.hpp" +#include "test_util.hpp" + +namespace { +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device:: + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + // clang-format off + InDataType, // + 
WeiDataType, // + OutDataType, // + InDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + SpatialDims, // SptialDims + 64, // BlockSize + 16, // MPerBlock + 16, // NPerBlock + 4, // K0PerBlock + 1, // K1 + 16, // MPerXDL + 16, // NPerXDL + 1, // MXdlPerWave + 1, // NXdlPerWave + S<1, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 1, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 1, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockTransferAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector +// clang-format on + +template +auto GetHostTensors(const ck::conv_util::ConvParams& params) +{ + std::vector input_dims{static_cast(params.N), + static_cast(params.C)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths), + std::end(params.input_spatial_lengths)); + + std::vector filter_dims{static_cast(params.K), + static_cast(params.C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths), + std::end(params.filter_spatial_lengths)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N), + static_cast(params.K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(ck::conv_util::GetHostTensorDescriptor(input_dims, 
InLayout{})); + Tensor weights(ck::conv_util::GetHostTensorDescriptor(filter_dims, WeiLayout{})); + Tensor host_output( + ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); + Tensor device_output( + ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); + + std::generate(input.begin(), input.end(), [n = 0]() mutable { + return InDataType(n++) * InDataType(0.1f); + }); + std::fill(weights.begin(), weights.end(), WeiDataType(0.5f)); + std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); + std::fill(device_output.begin(), device_output.end(), OutDataType(0.f)); + + return std::make_tuple(input, weights, host_output, device_output); +} + +template +void RunReferenceConv(const ck::conv_util::ConvParams& params, + const Tensor& input, + const Tensor& weights, + Tensor& output) +{ + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + output, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); +} + +template +void RunConv(const ck::conv_util::ConvParams& params, + const Tensor& input, + const Tensor& weights, + Tensor& output) +{ + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + + auto conv = DeviceConvNDFwdInstance(); + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + 
static_cast(out_device_buf.GetDeviceBuffer()), + params.N, + params.K, + params.C, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "Error! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + invoker.Run(argument); + out_device_buf.FromDevice(output.mData.data()); +} + +bool TestConv2DNHWC() +{ + bool res{true}; + ck::conv_util::ConvParams params; + params.N = 2; + params.K = 16; + params.C = 4; + params.input_spatial_lengths = std::vector{16, 16}; + params.conv_filter_strides = std::vector{1, 1}; + + auto host_tensors = GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& host_output = std::get<2>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + RunReferenceConv<2>(params, input, weights, host_output); + RunConv<2>(params, input, weights, device_output); + res = res && + test_util::check_err( + device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + + return res; +} + +bool TestConv1DNWC() +{ + bool res{true}; + ck::conv_util::ConvParams params; + params.num_dim_spatial = 1; + params.N = 2; + params.K = 16; + params.C = 4; + params.filter_spatial_lengths = std::vector{3}; + params.input_spatial_lengths = std::vector{16}; + params.conv_filter_strides = std::vector{1}; + params.conv_filter_dilations = std::vector{1}; + params.input_left_pads = std::vector{1}; + params.input_right_pads = std::vector{1}; + + auto host_tensors = GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& host_output = 
std::get<2>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + RunReferenceConv<1>(params, input, weights, host_output); + RunConv<1>(params, input, weights, device_output); + res = res && + test_util::check_err( + device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + + return res; +} + +} // anonymous namespace + +int main() +{ + bool res{true}; + res = TestConv1DNWC(); + std::cout << "TestConv1DNWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv2DNHWC(); + std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; +} diff --git a/test/include/test_util.hpp b/test/include/test_util.hpp new file mode 100644 index 0000000000..f779c3dd1d --- /dev/null +++ b/test/include/test_util.hpp @@ -0,0 +1,84 @@ +#ifndef TEST_UTIL_HPP +#define TEST_UTIL_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace test_util { + +template +typename std::enable_if::value, bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg, + T rtol = static_cast(1e-5), + T atol = static_cast(1e-8)) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + T err = 0; + T max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + err = std::abs(out[i] - ref[i]); + if(err > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) || !std::isfinite(ref[i])) + { + max_err = err > max_err ? 
namespace test_util {

// Tolerance-based comparison for floating-point vectors.
// An element fails when |out[i] - ref[i]| > atol + rtol * |ref[i]| or when
// either value is non-finite. Prints at most the first few mismatches
// (each followed by `msg`) plus the largest observed error, and returns
// false if any element failed.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, bool>::type
check_err(const std::vector<T>& out,
          const std::vector<T>& ref,
          const std::string& msg,
          T rtol = static_cast<T>(1e-5),
          T atol = static_cast<T>(1e-8))
{
    if(out.size() != ref.size())
    {
        std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl
                  << msg << std::endl;
        return false;
    }

    bool pass      = true;
    int n_failures = 0;
    T worst        = std::numeric_limits<T>::min();
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const T diff = std::abs(out[i] - ref[i]);
        const bool bad = diff > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) ||
                         !std::isfinite(ref[i]);
        if(bad)
        {
            worst = diff > worst ? diff : worst;
            ++n_failures;
            if(n_failures < 5)
            {
                std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref["
                          << i << "]: " << out[i] << "!=" << ref[i] << std::endl
                          << msg << std::endl;
            }
            pass = false;
        }
    }
    if(!pass)
    {
        std::cout << std::setw(12) << std::setprecision(7) << "max err: " << worst << std::endl;
    }
    return pass;
}

// Exact comparison for integral vectors. The trailing tolerance parameters
// are accepted (and ignored) so both overloads share one call signature.
template <typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type check_err(
    const std::vector<T>& out, const std::vector<T>& ref, const std::string& msg, T = 0, T = 0)
{
    if(out.size() != ref.size())
    {
        std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl
                  << msg << std::endl;
        return false;
    }

    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        if(out[i] != ref[i])
        {
            std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << "!=" << ref[i]
                      << std::endl
                      << msg << std::endl;
            return false;
        }
    }
    return true;
}

} // namespace test_util
last) const + { + std::fill(first, last, m_value); + } +}; + +template , + typename FillWeightsOp = FillConstant> +Tensor RunReferenceConv(const ck::conv_util::ConvParams& params, + const FillInputOp& fill_input_op = FillInputOp{0}, + const FillWeightsOp& fill_weights_op = FillWeightsOp{0.5f}) +{ + std::vector input_dims{static_cast(params.N), + static_cast(params.C)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths), + std::end(params.input_spatial_lengths)); + + std::vector filter_dims{static_cast(params.K), + static_cast(params.C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths), + std::end(params.filter_spatial_lengths)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N), + static_cast(params.K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(ck::conv_util::GetHostTensorDescriptor(input_dims, InLayout{})); + Tensor weights(ck::conv_util::GetHostTensorDescriptor(filter_dims, WeiLayout{})); + Tensor host_output( + ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); + + fill_input_op(input.begin(), input.end()); + fill_weights_op(weights.begin(), weights.end()); + std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + return host_output; +} + +bool TestConv2DNHWC() +{ + bool res{true}; + ck::conv_util::ConvParams params; + params.N = 1; + params.K = 1; + params.C = 2; + params.filter_spatial_lengths = 
std::vector{3, 3}; + params.input_spatial_lengths = std::vector{6, 6}; + params.conv_filter_strides = std::vector{1, 1}; + params.conv_filter_dilations = std::vector{1, 1}; + params.input_left_pads = std::vector{0, 0}; + params.input_right_pads = std::vector{0, 0}; + + auto out_tensor = RunReferenceConv<2>(params); + std::vector ref_dims{1, 1, 4, 4}; + std::vector ref_data{130.5, + 148.5, + 166.5, + 184.5, + 238.5, + 256.5, + 274.5, + 292.5, + 346.5, + 364.5, + 382.5, + 400.5, + 454.5, + 472.5, + 490.5, + 508.5}; + res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + + params.N = 1; + params.K = 2; + params.C = 2; + params.filter_spatial_lengths = std::vector{3, 3}; + params.input_spatial_lengths = std::vector{12, 12}; + params.conv_filter_strides = std::vector{2, 2}; + params.conv_filter_dilations = std::vector{2, 2}; + params.input_left_pads = std::vector{1, 1}; + params.input_right_pads = std::vector{1, 1}; + + out_tensor = RunReferenceConv<2>(params); + ref_dims = std::vector{1, 2, 5, 5}; + ref_data = std::vector{ + 210., 210., 327., 327., 351., 351., 375., 375., 399., 399., + 459., 459., 706.5, 706.5, 742.5, 742.5, 778.5, 778.5, 814.5, 814.5, + 747., 747., 1138.5, 1138.5, 1174.5, 1174.5, 1210.5, 1210.5, 1246.5, 1246.5, + 1035., 1035., 1570.5, 1570.5, 1606.5, 1606.5, 1642.5, 1642.5, 1678.5, 1678.5, + 1323., 1323., 2002.5, 2002.5, 2038.5, 2038.5, 2074.5, 2074.5, 2110.5, 2110.5}; + res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + + return res; +} + +bool TestConv1DNWC() +{ + bool res{true}; + ck::conv_util::ConvParams params; + params.num_dim_spatial = 1; + params.N = 1; + params.K = 1; + params.C = 2; + params.filter_spatial_lengths = 
std::vector{3}; + params.input_spatial_lengths = std::vector{6}; + params.conv_filter_strides = std::vector{1}; + params.conv_filter_dilations = std::vector{1}; + params.input_left_pads = std::vector{0}; + params.input_right_pads = std::vector{0}; + + auto out_tensor = RunReferenceConv<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>(params); + std::vector ref_dims{1, 1, 4}; + std::vector ref_data{7.5, 13.5, 19.5, 25.5}; + res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + + params.num_dim_spatial = 1; + params.N = 1; + params.K = 2; + params.C = 2; + params.filter_spatial_lengths = std::vector{3}; + params.input_spatial_lengths = std::vector{12}; + params.conv_filter_strides = std::vector{2}; + params.conv_filter_dilations = std::vector{2}; + params.input_left_pads = std::vector{1}; + params.input_right_pads = std::vector{1}; + + out_tensor = RunReferenceConv<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>(params); + ref_dims = std::vector{1, 2, 5}; + ref_data = std::vector{9., 9., 19.5, 19.5, 31.5, 31.5, 43.5, 43.5, 55.5, 55.5}; + res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + + params.num_dim_spatial = 1; + params.N = 2; + params.K = 16; + params.C = 4; + params.filter_spatial_lengths = std::vector{3}; + params.input_spatial_lengths = std::vector{16}; + params.conv_filter_strides = std::vector{1}; + params.conv_filter_dilations = std::vector{1}; + params.input_left_pads = std::vector{1}; + params.input_right_pads = std::vector{1}; + + 
auto out_tensor2 = + RunReferenceConv<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>(params, [](auto first, auto last) { + std::generate(first, last, [n = 0]() mutable { return float(n++) * float(0.1f); }); + }); + + ref_dims = std::vector{2, 16, 16}; + ref_data = std::vector{ + 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, + 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, + 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, + 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, + 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, + 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, + 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, + 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, + 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, + 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, + 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, + 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, + 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, + 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, + 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, + 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, + 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, + 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, + 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, + 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, + 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, + 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, + 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, + 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, + 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, + 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, + 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, + 32.100002, 32.100002, 32.100002, 32.100002, 
32.100002, 32.100002, 32.100002, 32.100002, + 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, + 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, + 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, + 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, + 27., 27., 27., 27., 27., 27., 27., 27., + 27., 27., 27., 27., 27., 27., 27., 27., + 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, + 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, + 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, + 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, + 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, + 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, + 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, + 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, + 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, + 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, + 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, + 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, + 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, + 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, + 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, + 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, + 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, + 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, + 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, + 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, + 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, + 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, + 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, + 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, + 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, + 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, + 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, + 
72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, + 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, + 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4}; + res = res && test_util::check_err(out_tensor2.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && test_util::check_err(out_tensor2.mData, ref_data, "Error: incorrect results!"); + + return res; +} + +} // anonymous namespace + +int main(void) +{ + bool res{true}; + res = TestConv2DNHWC(); + std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv1DNWC(); + std::cout << "TestConv1DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return 0; +}