diff --git a/composable_kernel/include/config.hpp b/composable_kernel/include/config.hpp index 3126958b67..7f51d29715 100644 --- a/composable_kernel/include/config.hpp +++ b/composable_kernel/include/config.hpp @@ -151,6 +151,12 @@ #define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE 1 #endif +// workaround for verifaction failure, due to compiler regression, for conv bwd-data fp16 using some +// tuning parameter +#ifndef CK_WORKAROUND_SWDEV_325164 +#define CK_WORKAROUND_SWDEV_325164 1 +#endif + namespace ck { enum InMemoryDataOperationEnum_t diff --git a/composable_kernel/include/tensor_operation/element_wise_operation.hpp b/composable_kernel/include/tensor_operation/element_wise_operation.hpp index 1e45a5b7eb..47f5005bc6 100644 --- a/composable_kernel/include/tensor_operation/element_wise_operation.hpp +++ b/composable_kernel/include/tensor_operation/element_wise_operation.hpp @@ -1,5 +1,6 @@ #ifndef CK_ELEMENT_WISE_OPERATION_HPP #define CK_ELEMENT_WISE_OPERATION_HPP +#include "data_type.hpp" #include "data_type.hpp" diff --git a/composable_kernel/include/utility/magic_division.hpp b/composable_kernel/include/utility/magic_division.hpp index d87be11c75..6102576717 100644 --- a/composable_kernel/include/utility/magic_division.hpp +++ b/composable_kernel/include/utility/magic_division.hpp @@ -25,21 +25,30 @@ struct MagicDivision // uint32_t __host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor) { - // assert(divisior >= 1 && divisior <= INT32_MAX); - uint32_t shift = 0; - for(shift = 0; shift < 32; ++shift) + // WARNING: magic division is only applicable for division inside this range. + // You should use the return value of CalculateMagicNumbers, if division is not inside this + // range. The "else" logic below is to quiet down run-time error. + if(divisor >= 1 && divisor <= INT32_MAX) { - if((1U << shift) >= divisor) + uint32_t shift = 0; + for(shift = 0; shift < 32; ++shift) { - break; + if((1U << shift) >= divisor) + { + break; + } } + + uint64_t one = 1; + uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1; + // assert(multiplier <= 0xffffffffUL); + + return make_tuple(uint32_t(multiplier), shift); + } + else + { + return make_tuple(uint32_t(0), uint32_t(0)); } - - uint64_t one = 1; - uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1; - // assert(multiplier <= 0xffffffffUL); - - return make_tuple(uint32_t(multiplier), shift); } __host__ __device__ static constexpr uint32_t CalculateMagicMultiplier(uint32_t divisor) diff --git a/device_operation/CMakeLists.txt b/device_operation/CMakeLists.txt index be1fa4373a..6b5a50a640 100644 --- a/device_operation/CMakeLists.txt +++ b/device_operation/CMakeLists.txt @@ -101,16 +101,25 @@ set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp; ) -add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) -add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE}) -add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) -add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) +# device_conv2d_bwd_data_instance +set(DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; +) + +add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) +add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) +add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) add_library(device_batched_gemm_instance SHARED ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE}) add_library(device_conv1d_fwd_instance SHARED ${DEVICE_CONV1D_FWD_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) +add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE}) +add_library(device_conv2d_bwd_data_instance SHARED ${DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE}) target_include_directories(device_gemm_instance SYSTEM PUBLIC $) target_include_directories(device_gemm_bias_2d_instance SYSTEM PUBLIC $) @@ -122,6 +131,7 @@ target_include_directories(device_conv2d_fwd_instance SYSTEM PUBLIC $) target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $) target_include_directories(device_conv2d_fwd_bias_relu_atomic_add_instance SYSTEM PUBLIC $) +target_include_directories(device_conv2d_bwd_data_instance SYSTEM PUBLIC $) target_compile_features(device_gemm_instance PUBLIC) target_compile_features(device_gemm_bias_2d_instance PUBLIC) @@ -133,6 +143,7 @@ target_compile_features(device_conv2d_fwd_instance PUBLIC) target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC) target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC) target_compile_features(device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC) +target_compile_features(device_conv2d_bwd_data_instance PUBLIC) set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_gemm_bias_2d_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) @@ -144,6 +155,7 @@ set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(device_conv2d_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) install(TARGETS device_gemm_bias_2d_instance LIBRARY DESTINATION lib) @@ -155,3 +167,4 @@ install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib) install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv2d_bwd_data_instance LIBRARY DESTINATION lib) diff --git a/device_operation/include/convolution_backward_data_specialization.hpp b/device_operation/include/convolution_backward_data_specialization.hpp new file mode 100644 index 0000000000..4c1d6747c4 --- /dev/null +++ b/device_operation/include/convolution_backward_data_specialization.hpp @@ -0,0 +1,17 @@ +#ifndef CONVOLUTION_BACKWARD_DATA_SPECIALIZATION +#define CONVOLUTION_BACKWARD_DATA_SPECIALIZATION + +namespace ck { +namespace tensor_operation { +namespace device { + +enum ConvolutionBackwardDataSpecialization_t +{ + Default, + Filter1x1Stride1Pad0, +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/device_operation/include/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..185b96626b --- /dev/null +++ b/device_operation/include/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,814 @@ +#ifndef DEVICE_CONV2D_BWD_DATA_XDL_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONV2D_BWD_DATA_XDL_NHWC_KYXC_NHWK_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_bwd_data.hpp" +#include "convolution_backward_data_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvBwdData +{ + using DeviceOp = DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = OutDataType; + using BDataType = WeiDataType; + using CDataType = InDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static_assert((K1 % ABlockTransferThreadClusterLengths_K0_M_K1{}[I2]) % + ABlockTransferSrcScalarPerVector == + 0); + static_assert((NPerBlock / BBlockTransferThreadClusterLengths_K0_N_K1{}[I1]) % + BBlockTransferSrcScalarPerVector == + 0); + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + index_t i_ytilda, + index_t i_xtilda) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const auto K0 = K / K1; + + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + const auto wei_k_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilda = ConvStrideH / GcdStrideDilationH; + const auto XTilda = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilda); + const auto XDot = math::integer_divide_ceil(X, XTilda); + + const auto HTilda = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilda = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilda and WTilda that contribute to non-padding area of input tensor + const auto IHTildaSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH); + const auto IWTildaSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW); + + const auto IHTildaSliceEnd = math::min( + HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildaSliceEnd = math::min( + WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin; + const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda); + + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilda), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilda), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilda_xdot_wtilda_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilda), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilda), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ytilda), + make_freeze_transform(i_xtilda), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilda, HTilda), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilda, WTilda), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilda), + make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_freeze_transform(i_xtilda), + make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_htildaslice_wtildaslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 0, 0)); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum_t::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_in_grid}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{wei_element_op}, + c_element_op_{in_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilda = ConvStrideH / GcdStrideDilationH; + const auto XTilda = ConvStrideW / GcdStrideDilationW; + + for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda) + { + for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda) + { + const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + i_ytilda, + i_xtilda); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], M01_, N01_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); + + block_2_ctile_map_container_.push_back( + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01)); + } + } + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + std::vector a_grid_desc_k0_m_k1_container_; + std::vector b_grid_desc_k0_n_k1_container_; + std::vector c_grid_desc_m_n_container_; + std::vector + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_; + std::vector block_2_ctile_map_container_; + index_t M01_; + index_t N01_; + OutElementwiseOperation a_element_op_; + WeiElementwiseOperation b_element_op_; + InElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, int nrepeat = 1) + { + nrepeat = 1; + float ave_time = 0; + for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_container_{" + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_container_{" + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m_n_container_{ " + << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " + << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4) + << ", " + << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5) + << " ) " << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i], + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const index_t grid_size = + GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]); + + const auto K0 = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + OutElementwiseOperation, + WeiElementwiseOperation, + InElementwiseOperation, + remove_reference_t, + true>; + + ave_time += launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_container_[i]); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + OutElementwiseOperation, + WeiElementwiseOperation, + InElementwiseOperation, + remove_reference_t, + false>; + + ave_time += launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_container_[i]); + } + } + return ave_time; + } + + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + for(int i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i], + arg.M01_, + arg.N01_)) + { + return false; + } + } + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(void* p_in_grid, + const void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_conv_bwd_data.hpp b/device_operation/include/device_conv_bwd_data.hpp new file mode 100644 index 0000000000..1d08af1a05 --- /dev/null +++ b/device_operation/include/device_conv_bwd_data.hpp @@ -0,0 +1,47 @@ +#ifndef DEVICE_CONV_BWD_DATA_HPP +#define DEVICE_CONV_BWD_DATA_HPP + +#include +#include "device_base.hpp" +#include "element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvBwdData : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(void* p_in, + const void* p_wei, + const void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvBwdDataPtr = std::unique_ptr< + DeviceConvBwdData>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp new file mode 100644 index 0000000000..72cc021643 --- /dev/null +++ b/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -0,0 +1,83 @@ +#include +#include "config.hpp" +#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using BF16 = ushort; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = + std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000..556be415f1 --- /dev/null +++ b/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,85 @@ +#include +#include "config.hpp" +#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, +#if !CK_WORKAROUND_SWDEV_325164 + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, +#endif + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = + std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 0000000000..215156398b --- /dev/null +++ b/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,82 @@ +#include +#include "config.hpp" +#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp new file mode 100644 index 0000000000..38f79bf937 --- /dev/null +++ b/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -0,0 +1,83 @@ +#include +#include "config.hpp" +#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using DataType = int8_t; +using AccType = int32_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances = std::tuple< + // clang-format off + //####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //#####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#####################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances{}); + add_device_operation_instances( + instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp index 575048399b..52c9a9f83d 100644 --- a/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp +++ b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -32,19 +32,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple< //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> // clang-format on >; @@ -54,19 +54,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = std::tuple< //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> // clang-format on >; @@ -76,19 +76,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = std::tuple //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> // clang-format on >; diff --git a/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp index c9af26ed39..63be85ff7a 100644 --- a/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp +++ b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -32,19 +32,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances = std::tuple< //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> // clang-format on >; @@ -54,19 +54,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances = std::tuple< //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> // clang-format on >; @@ -76,19 +76,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = std::tuple //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> // clang-format on >; diff --git a/example/12_conv2d_bwd_data_xdl/README.md b/example/12_conv2d_bwd_data_xdl/README.md new file mode 100644 index 0000000000..547c544445 --- /dev/null +++ b/example/12_conv2d_bwd_data_xdl/README.md @@ -0,0 +1,79 @@ +# Instructions for ```conv2d_bwd_data_xdl``` Example + +## Docker script +```bash +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ +rocm/tensorflow:rocm4.3.1-tf2.6-dev \ +/bin/bash +``` + +## Build ```conv2d_bwd_data_xdl``` +```bash +mkdir build && cd build +``` + +```bash +# Need to specify target ID, example below is gfx908 +cmake \ +-D BUILD_DEV=OFF \ +-D CMAKE_BUILD_TYPE=Release \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +.. +``` + +```bash + make -j conv2d_bwd_data_xdl +``` + +## Run ```conv2d_bwd_data_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx +./bin/conv2d_bwd_data_xdl 0 1 5 +``` + +Result +``` +in_n_c_hi_wi: dim 4, lengths {128, 256, 71, 71}, strides {1290496, 1, 18176, 256} +wei_k_c_y_x: dim 4, lengths {256, 256, 3, 3}, strides {2304, 1, 768, 256} +out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +arg.a_grid_desc_k0_m_k1_container_{128, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{128, 256, 8} +arg.c_grid_desc_m_n_container_{ 175232, 256} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 4, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {2738, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{64, 256, 8} +arg.c_grid_desc_m_n_container_{ 175232, 256} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 4, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {2738, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{64, 256, 8} +arg.c_grid_desc_m_n_container_{ 175232, 256} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 4, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {2738, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +arg.a_grid_desc_k0_m_k1_container_{32, 175232, 8} +arg.b_grid_desc_k0_n_k1_container_{32, 256, 8} +arg.c_grid_desc_m_n_container_{ 175232, 256} +arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 4, 2, 2, 4, 2 ) +launch_and_time_kernel: grid_dim {2738, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +Perf: 2.45966 ms, 79.5597 TFlops, 169.325 GB/s +``` diff --git a/example/12_conv2d_bwd_data_xdl/conv2d_bwd_data_xdl.cpp b/example/12_conv2d_bwd_data_xdl/conv2d_bwd_data_xdl.cpp new file mode 100644 index 0000000000..7f289c1938 --- /dev/null +++ b/example/12_conv2d_bwd_data_xdl/conv2d_bwd_data_xdl.cpp @@ -0,0 +1,247 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "tensor_layout.hpp" +#include "element_wise_operation.hpp" +#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp" +#include "reference_conv_bwd_data.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto ConvBwdDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default; + +using DeviceConvBwdDataInstance = ck::tensor_operation::device:: + DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdDefault, // ConvolutionBackwardDataSpecialization_t + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 2, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, + 1>; // GemmCThreadTransferDstScalarPerVector + +using ReferenceConvBwdInstance = ck::tensor_operation::host::ReferenceConvBwdData; + +int main(int argc, char* argv[]) +{ + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + + // Conv shape + ck::index_t N = 128; + ck::index_t K = 256; + ck::index_t C = 256; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t conv_stride_h = 2; + ck::index_t conv_stride_w = 2; + ck::index_t conv_dilation_h = 1; + ck::index_t conv_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + } + else if(argc == 19) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + C = std::stoi(argv[6]); + Y = std::stoi(argv[7]); + X = std::stoi(argv[8]); + Hi = std::stoi(argv[9]); + Wi = std::stoi(argv[10]); + conv_stride_h = std::stoi(argv[11]); + conv_stride_w = std::stoi(argv[12]); + conv_dilation_h = std::stoi(argv[13]); + conv_dilation_w = std::stoi(argv[14]); + in_left_pad_h = std::stoi(argv[15]); + in_left_pad_w = std::stoi(argv[16]); + in_right_pad_h = std::stoi(argv[17]); + in_right_pad_w = std::stoi(argv[18]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(0); + } + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + const std::vector conv_filter_strides{{conv_stride_h, conv_stride_w}}; + const std::vector conv_filter_dilations{{conv_dilation_h, conv_dilation_w}}; + const std::vector input_left_pads{{in_left_pad_h, in_left_pad_w}}; + const std::vector input_right_pads{{in_right_pad_h, in_right_pad_w}}; + + // tensor layout + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + }; + + Tensor out_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo)); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X)); + Tensor in_n_c_hi_wi_host_result(f_host_tensor_descriptor(N, C, Hi, Wi)); + Tensor in_n_c_hi_wi_device_result(f_host_tensor_descriptor(N, C, Hi, Wi)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * + in_n_c_hi_wi_device_result.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + + // do GEMM + auto conv = DeviceConvBwdDataInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + std::vector{{Hi, Wi}}, + std::vector{{Y, X}}, + std::vector{{Ho, Wo}}, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float ave_time = invoker.Run(argument, nrepeat); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto ref_conv = ReferenceConvBwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result, + wei_k_c_y_x, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); + + check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result); + } +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 7e6daa7ad6..c468d753d5 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -24,6 +24,7 @@ set(GEMM_XDL_ALPHA_BETA_SOURCE 8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp) set(CONV2D_FWD_XDL_INT8_SOURCE 9_conv2d_fwd_xdl_int8/conv2d_fwd_xdl_int8.cpp) set(CONV3D_FWD_XDL_SOURCE 10_conv3d_fwd_xdl/conv3d_fwd_xdl.cpp) set(CONVND_FWD_XDL_SOURCE 11_convnd_fwd_xdl/convnd_fwd_xdl.cpp) +set(CONV2D_BWD_DATA_XDL_SOURCE 12_conv2d_bwd_data_xdl/conv2d_bwd_data_xdl.cpp) add_executable(gemm_xdl ${GEMM_XDL_SOURCE}) add_executable(gemm_xdl_bias_relu ${GEMM_XDL_BIAS_RELU_SOURCE}) @@ -36,6 +37,7 @@ add_executable(gemm_xdl_alpha_beta ${GEMM_XDL_ALPHA_BETA_SOURCE}) add_executable(conv2d_fwd_xdl_int8 ${CONV2D_FWD_XDL_INT8_SOURCE}) add_executable(conv3d_fwd_xdl ${CONV3D_FWD_XDL_SOURCE}) add_executable(convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE}) +add_executable(conv2d_bwd_data_xdl ${CONV2D_BWD_DATA_XDL_SOURCE}) target_link_libraries(gemm_xdl PRIVATE host_tensor) target_link_libraries(gemm_xdl_bias_relu PRIVATE host_tensor) @@ -48,3 +50,5 @@ target_link_libraries(gemm_xdl_alpha_beta PRIVATE host_tensor) target_link_libraries(conv2d_fwd_xdl_int8 PRIVATE host_tensor) target_link_libraries(conv3d_fwd_xdl PRIVATE host_tensor) target_link_libraries(convnd_fwd_xdl PRIVATE host_tensor) +target_link_libraries(conv2d_bwd_data_xdl PRIVATE host_tensor) + diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 18b7a89363..7187147647 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -25,6 +25,7 @@ set(PROFILER_SOURCE src/profile_conv_fwd_bias_relu_add.cpp src/profile_conv_fwd_bias_relu_atomic_add.cpp src/profile_batched_gemm.cpp + src/profile_conv_bwd_data.cpp ) add_executable(ckProfiler ${PROFILER_SOURCE}) @@ -39,3 +40,4 @@ target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance) +target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_data_instance) diff --git a/profiler/include/profile_conv_bwd_data_impl.hpp b/profiler/include/profile_conv_bwd_data_impl.hpp new file mode 100644 index 0000000000..019020c2ac --- /dev/null +++ b/profiler/include/profile_conv_bwd_data_impl.hpp @@ -0,0 +1,278 @@ +#pragma once +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "device_conv_bwd_data.hpp" +#include "element_wise_operation.hpp" +#include "reference_conv_bwd_data.hpp" + +using F16 = ck::half_t; +using F32 = float; +using BF16 = ushort; +using INT8 = int8_t; +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using DeviceConvBwdDataNoOpPtr = + DeviceConvBwdDataPtr; +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector&); +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_conv_bwd_data_impl(int do_verification, + int init_method, + bool do_log, + int nrepeat, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) +{ + const ck::index_t Y = filter_spatial_lengths[0]; + const ck::index_t X = filter_spatial_lengths[1]; + + const ck::index_t Hi = input_spatial_lengths[0]; + const ck::index_t Wi = input_spatial_lengths[1]; + + const ck::index_t Ho = output_spatial_lengths[0]; + const ck::index_t Wo = output_spatial_lengths[1]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(is_same::value || + is_same::value || + is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi_host_result(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor in_n_c_hi_wi_device_result( + f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + if(do_verification) + { + using ReferenceConvBwdDataInstance = + ck::tensor_operation::host::ReferenceConvBwdData; + + auto ref_conv = ReferenceConvBwdDataInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result, + wei_k_c_y_x, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + } + + DeviceMem in_device_buf(sizeof(InDataType) * + in_n_c_hi_wi_device_result.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using DeviceConvBwdDataNoOpPtr = + ck::tensor_operation::device::DeviceConvBwdDataPtr; + + // add device Conv instances + std::vector conv_ptrs; + if constexpr(ck::is_same_v, float> && + ck::is_same_v, float> && + ck::is_same_v, float>) + { + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + } + else if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + } + else if constexpr(ck::is_same_v, ushort> && + ck::is_same_v, ushort> && + ck::is_same_v, ushort>) + { + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + } + else if constexpr(ck::is_same_v, int8_t> && + ck::is_same_v, int8_t> && + ck::is_same_v, int8_t>) + { + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); + } + + if(conv_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device Conv instances + for(auto& conv_ptr : conv_ptrs) + { + auto argument_ptr = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + auto invoker_ptr = conv_ptr->MakeInvokerPointer(); + + if(conv_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = conv_ptr->GetTypeString(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv_name << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); + + check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result); + + if(do_log) + { + LogRangeAsType(std::cout << "in : ", out_n_k_ho_wo.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_host : ", in_n_c_hi_wi_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_device: ", in_n_c_hi_wi_device_result.mData, ",") + << std::endl; + } + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/profile_conv_bwd_data.cpp b/profiler/src/profile_conv_bwd_data.cpp new file mode 100644 index 0000000000..613c6879e6 --- /dev/null +++ b/profiler/src/profile_conv_bwd_data.cpp @@ -0,0 +1,191 @@ +#include +#include +#include +#include +#include +#include +#include "profile_conv_bwd_data_impl.hpp" + +enum ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +enum ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +int profile_conv_bwd_data(int argc, char* argv[]) +{ + if(argc != 25) + { + printf("arg1: tensor operation (conv_bwd: BackwardConvolution)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: run kernel # of times (>1)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + const int data_type = static_cast(std::stoi(argv[2])); + const int in_layout = static_cast(std::stoi(argv[3])); + const int wei_layout = static_cast(std::stoi(argv[4])); + const int out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const int nrepeat = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = std::stoi(argv[24]); + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_bwd_data_impl<2, + float, + float, + float, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + nrepeat, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + } + else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_bwd_data_impl<2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + nrepeat, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_bwd_data_impl<2, + uint16_t, + uint16_t, + uint16_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + nrepeat, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8 && in_layout == ConvInputLayout::NHWC && + wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) + { + ck::profiler::profile_conv_bwd_data_impl<2, + int8_t, + int8_t, + int8_t, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK>( + do_verification, + init_method, + do_log, + nrepeat, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + } + else + { + throw std::runtime_error("wrong! this Conv data_type & layout is not implemented"); + } + + return 1; +} diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index c6a5a4cbc9..2ea26105a0 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -14,6 +14,7 @@ int profile_conv_fwd(int, char*[]); int profile_conv_fwd_bias_relu(int, char*[]); int profile_conv_fwd_bias_relu_add(int, char*[]); int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); +int profile_conv_bwd_data(int, char*[]); int main(int argc, char* argv[]) { @@ -53,6 +54,10 @@ int main(int argc, char* argv[]) { return profile_conv_fwd_bias_relu_atomic_add(argc, argv); } + else if(strcmp(argv[1], "conv_bwd") == 0) + { + return profile_conv_bwd_data(argc, argv); + } else { // clang-format off @@ -63,7 +68,8 @@ int main(int argc, char* argv[]) " conv_fwd: ForwardConvolution\n" " conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n" " conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n" - " conv_fwd_bias_relu_atomic_add: ForwardConvolution+Bias+ReLU+AtomicAdd\n"); + " conv_fwd_bias_relu_atomic_add: ForwardConvolution+Bias+ReLU+AtomicAdd\n" + " conv_bwd: BackwardConvolution\n"); // clang-format on return 0; diff --git a/reference_operation/include/reference_conv_bwd_data.hpp b/reference_operation/include/reference_conv_bwd_data.hpp new file mode 100644 index 0000000000..e4366e9ace --- /dev/null +++ b/reference_operation/include/reference_conv_bwd_data.hpp @@ -0,0 +1,192 @@ +#ifndef REFERENCE_CONV_BWD_DATA_HPP +#define REFERENCE_CONV_BWD_DATA_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X] +template +struct ReferenceConvBwdData : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : in_n_c_hi_wi_{in_n_c_hi_wi}, + wei_k_c_y_x_{wei_k_c_y_x}, + out_n_k_ho_wo_{out_n_k_ho_wo}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + Tensor& in_n_c_hi_wi_; + const Tensor& wei_k_c_y_x_; + const Tensor& out_n_k_ho_wo_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvBwdData::Argument; + + float Run(const Argument& arg) + { + auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { + std::size_t K = arg.wei_k_c_y_x_.mDesc.GetLengths()[0]; + std::size_t Y = arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; + std::size_t X = arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; + + std::size_t Ho = arg.out_n_k_ho_wo_.mDesc.GetLengths()[2]; + std::size_t Wo = arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]; + + float v_acc = 0; + + for(int y = 0; y < Y; ++y) + { + int h_tmp = hi + arg.in_left_pads_[0] - y * arg.conv_dilations_[0]; + if(h_tmp % arg.conv_strides_[0] == 0) + { + int ho = h_tmp / arg.conv_strides_[0]; + if(ho >= 0 && ho < Ho) + { + for(int x = 0; x < X; ++x) + { + int w_tmp = wi + arg.in_left_pads_[1] - x * arg.conv_dilations_[1]; + if(w_tmp % arg.conv_strides_[1] == 0) + { + int wo = w_tmp / arg.conv_strides_[1]; + if(wo >= 0 && wo < Wo) + { + for(int k = 0; k < K; ++k) + { + float v_out = 0; + float v_wei = 0; + + arg.out_element_op_( + v_out, + ck::type_convert( + arg.out_n_k_ho_wo_(n, k, ho, wo))); + arg.wei_element_op_(v_wei, + ck::type_convert( + arg.wei_k_c_y_x_(k, c, y, x))); + + v_acc += v_out * v_wei; + } + } + } + } + } + } + } + + float v_in; + arg.in_element_op_(v_in, v_acc); + arg.in_n_c_hi_wi_(n, c, hi, wi) = ck::type_convert(v_in); + }; + + make_ParallelTensorFunctor(f_nchw, + arg.in_n_c_hi_wi_.mDesc.GetLengths()[0], + arg.in_n_c_hi_wi_.mDesc.GetLengths()[1], + arg.in_n_c_hi_wi_.mDesc.GetLengths()[2], + arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, int) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvBwdData" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/test/conv2d_bwd_data/main.cpp b/test/conv2d_bwd_data/main.cpp new file mode 100644 index 0000000000..72ed6ee074 --- /dev/null +++ b/test/conv2d_bwd_data/main.cpp @@ -0,0 +1,319 @@ +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "device_conv_bwd_data.hpp" +#include "element_wise_operation.hpp" +#include "reference_conv_bwd_data.hpp" + +using F16 = ck::half_t; +using F32 = float; +using BF16 = ushort; +using INT8 = int8_t; +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_bwd_data_instance { + +using DeviceConvBwdDataNoOpPtr = + DeviceConvBwdDataPtr; +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector&); +void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector&); +} // namespace device_conv2d_bwd_data_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +template +static bool check_out(const Tensor& ref, const Tensor& result) +{ + float max_diff = 1e-6; + + for(int i = 0; i < ref.mData.size(); ++i) + { + float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); + if(max_diff < diff) + { + return false; + } + } + + return true; +} + +int main(int argc, char* argv[]) +{ + int data_type = 0; + int init_method = 0; + + // Conv shape + ck::index_t N = 128; + ck::index_t K = 256; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t conv_stride_h = 2; + ck::index_t conv_stride_w = 2; + ck::index_t conv_dilation_h = 1; + ck::index_t conv_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 3) + { + data_type = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + } + else if(argc == 18) + { + data_type = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + + N = std::stoi(argv[3]); + K = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + conv_stride_h = std::stoi(argv[10]); + conv_stride_w = std::stoi(argv[11]); + conv_dilation_h = std::stoi(argv[12]); + conv_dilation_w = std::stoi(argv[13]); + in_left_pad_h = std::stoi(argv[14]); + in_left_pad_w = std::stoi(argv[15]); + in_right_pad_h = std::stoi(argv[16]); + in_right_pad_w = std::stoi(argv[17]); + } + else + { + printf("arg1: data type (0=fp32 )\n"); + printf("arg2: verification (0=no, 1=yes)\n"); + printf("arg3: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg4: run kernel # of times (>1)\n"); + printf("arg5 to 19: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + auto Run = [&](auto input_type, auto wei_type, auto out_type) { + using InDataType = decltype(input_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using ReferenceConvBwdInstance = + ck::tensor_operation::host::ReferenceConvBwdData; + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + const std::vector input_spatial_lengths{{Hi, Wi}}; + const std::vector filter_spatial_lengths{{Y, X}}; + const std::vector output_spatial_lengths{{Ho, Wo}}; + const std::vector conv_filter_strides{{conv_stride_h, conv_stride_w}}; + const std::vector conv_filter_dilations{{conv_dilation_h, conv_dilation_w}}; + const std::vector input_left_pads{{in_left_pad_h, in_left_pad_w}}; + const std::vector input_right_pads{{in_right_pad_h, in_right_pad_w}}; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + }; + + Tensor out_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo)); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X)); + Tensor in_n_c_hi_wi_host_result(f_host_tensor_descriptor(N, C, Hi, Wi)); + Tensor in_n_c_hi_wi_device_result(f_host_tensor_descriptor(N, C, Hi, Wi)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * + in_n_c_hi_wi_device_result.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + + in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1{5}); + in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data()); + + // get host result + { + auto ref_conv = ReferenceConvBwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result, + wei_k_c_y_x, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + ref_invoker.Run(ref_argument); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device:: + DeviceConvBwdDataPtr; + + // add device Conv instances + std::vector conv_ptrs; + + if constexpr(ck::is_same_v, float> && + ck::is_same_v, float> && + ck::is_same_v, float>) + { + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + } + else if constexpr(ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t> && + ck::is_same_v, ck::half_t>) + { + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + } + else if constexpr(ck::is_same_v, ushort> && + ck::is_same_v, ushort> && + ck::is_same_v, ushort>) + { + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + } + else if constexpr(ck::is_same_v, int8_t> && + ck::is_same_v, int8_t> && + ck::is_same_v, int8_t>) + { + ck::tensor_operation::device::device_conv2d_bwd_data_instance:: + add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); + } + + if(conv_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + // profile device Conv instances + bool success = true; + for(auto& conv_ptr : conv_ptrs) + { + auto argument_ptr = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(conv_ptr->IsSupportedArgument(argument_ptr.get())) + { + auto invoker_ptr = conv_ptr->MakeInvokerPointer(); + invoker_ptr->Run(argument_ptr.get(), 1); + + in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); + + if(!check_out(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result)) + { + std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl; + success = false; + } + else + { + std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl; + } + } + else + { + std::cout << "Not support Info: " << conv_ptr->GetTypeString() << std::endl; + } + } + + if(success) + { + std::cout << "test conv2d bwd : Pass" << std::endl; + } + else + { + std::cout << "test conv2d bwd: Fail " << std::endl; + } + }; + + if(data_type == 0) + { + Run(float(), float(), F32()); + } + else if(data_type == 1) + { + Run(F16(), F16(), F16()); + } + else if(data_type == 2) + { + Run(BF16(), BF16(), BF16()); + } + else if(data_type == 3) + { + Run(INT8(), INT8(), INT8()); + } + else + { + return 1; + } + + return 0; +}