From 8890cc207d1b413592bceec9cd70384e13d32fe1 Mon Sep 17 00:00:00 2001 From: ltqin Date: Fri, 4 Feb 2022 12:29:58 +0800 Subject: [PATCH] References for conv2d fwd bias relu and add (#75) * add reference * clean up * add reference for conv * rename Co-authored-by: ltqin Co-authored-by: Chao Liu [ROCm/composable_kernel commit: 690c75a7eb7012bf0fd6fb3f6e129e83fbcbdb53] --- example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp | 127 ++++++------ .../conv2d_fwd_xdl_bias_relu.cpp | 129 ++++++------ .../conv2d_fwd_xdl_bias_relu_add.cpp | 137 ++++++------- example/CMakeLists.txt | 1 + host/include/reference_conv_fwd.hpp | 166 ++++++++++++++++ .../reference_conv_fwd_bias_activation.hpp | 172 ++++++++++++++++ ...reference_conv_fwd_bias_activation_add.hpp | 183 ++++++++++++++++++ 7 files changed, 706 insertions(+), 209 deletions(-) create mode 100644 host/include/reference_conv_fwd.hpp create mode 100644 host/include/reference_conv_fwd_bias_activation.hpp create mode 100644 host/include/reference_conv_fwd_bias_activation_add.hpp diff --git a/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp b/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp index ad428e2ef2..310de70b25 100644 --- a/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp +++ b/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp @@ -11,8 +11,9 @@ #include "host_tensor_generator.hpp" #include "device_tensor.hpp" #include "tensor_layout.hpp" -#include "device_operation/include/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" #include "element_wise_operation.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "reference_conv_fwd.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -33,65 +34,53 @@ using OutElementOp = ck::tensor_operation::element_wise::PassThrough; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; +// clang-format off using DeviceConvFwdInstance = ck::tensor_operation::device:: - DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K - // clang-format off -// | InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| -// | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| -// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvFwdDefault, // ConvForwardSpecialization + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl // clang-format on -template -void host_verify(const Tensor& in, - const Tensor& wei, - Tensor& out, - const std::vector& conv_strides, - const std::vector& conv_dilations, - const std::vector& in_left_pads, - const std::vector&, - const InElementOp& in_element_op, - const WeiElementOp& wei_element_op, - const OutElementOp& out_element_op) -{ - auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { - double v = 0; - for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) - { - for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) - { - int hi = ho * conv_strides[0] + y * conv_dilations[0] - in_left_pads[0]; - for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) - { - int wi = wo * conv_strides[1] + x * conv_dilations[1] - in_left_pads[1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - v += in_element_op(static_cast(in(n, c, hi, wi))) * - wei_element_op(static_cast(wei(k, c, y, x))); - } - } - } - } - double v2 = out(n, k, ho, wo); - - out_element_op(v2, v); - - out(n, k, ho, wo) = v2; - }; - - make_ParallelTensorFunctor(f_nchw, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); -} +using ReferenceConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; int main(int argc, char* argv[]) { @@ -265,16 +254,20 @@ int main(int argc, char* argv[]) if(do_verification) { - host_verify(in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo_host_result, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); + auto refConv = ReferenceConvFwdInstance{}; + auto refInvoker = refConv.MakeInvoker(); + + auto refArgument = refConv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + refInvoker.Run(refArgument); out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); diff --git a/example/5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp b/example/5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp index aa2605bbdf..79bd332709 100644 --- a/example/5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp +++ b/example/5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp @@ -11,8 +11,9 @@ #include "host_tensor_generator.hpp" #include "device_tensor.hpp" #include "tensor_layout.hpp" -#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" #include "element_wise_operation.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" +#include "reference_conv_fwd_bias_activation.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -37,63 +38,53 @@ static constexpr auto ConvFwdDefault = // clang-format off using DeviceConvFwdInstance = ck::tensor_operation::device:: - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K - // clang-format off -// | InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| -// | | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| -// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + MemorySet, // OutGlobalMemoryDataOperation + ConvFwdDefault, // ConvForwardSpecialization + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl // clang-format on -template -void host_reference_calculation(const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - const Tensor& bias_k, - const std::vector& conv_strides, - const std::vector& conv_dilations, - const std::vector& in_left_pads, - const std::vector& /* in_right_pads */, - const InElementOp& in_element_op, - const WeiElementOp& wei_element_op, - const OutElementOp& out_element_op) -{ - auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { - double v = 0; - for(int c = 0; c < wei_k_c_y_x.mDesc.GetLengths()[1]; ++c) - { - for(int y = 0; y < wei_k_c_y_x.mDesc.GetLengths()[2]; ++y) - { - int hi = ho * conv_strides[0] + y * conv_dilations[0] - in_left_pads[0]; - for(int x = 0; x < wei_k_c_y_x.mDesc.GetLengths()[3]; ++x) - { - int wi = wo * conv_strides[1] + x * conv_dilations[1] - in_left_pads[1]; - if(hi >= 0 && hi < in_n_c_hi_wi.mDesc.GetLengths()[2] && wi >= 0 && - wi < in_n_c_hi_wi.mDesc.GetLengths()[3]) - { - v += in_element_op(static_cast(in_n_c_hi_wi(n, c, hi, wi))) * - wei_element_op(static_cast(wei_k_c_y_x(k, c, y, x))); - } - } - } - } - - out_n_k_ho_wo(n, k, ho, wo) = out_element_op(v, bias_k(k)); - }; - - make_ParallelTensorFunctor(f_nchw, - out_n_k_ho_wo.mDesc.GetLengths()[0], - out_n_k_ho_wo.mDesc.GetLengths()[1], - out_n_k_ho_wo.mDesc.GetLengths()[2], - out_n_k_ho_wo.mDesc.GetLengths()[3])( - std::thread::hardware_concurrency()); -} +using ReferenceConvFwdInstance = + ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation; int main(int argc, char* argv[]) { @@ -277,17 +268,21 @@ int main(int argc, char* argv[]) if(do_verification) { - host_reference_calculation(in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo_host_result, - bias_k, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); + auto refConv = ReferenceConvFwdInstance{}; + auto refInvoker = refConv.MakeInvoker(); + + auto refArgument = refConv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + bias_k, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + refInvoker.Run(refArgument); out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); diff --git a/example/6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp b/example/6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp index 1353b65248..2b1414b05b 100644 --- a/example/6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp +++ b/example/6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp @@ -11,8 +11,9 @@ #include "host_tensor_generator.hpp" #include "device_tensor.hpp" #include "tensor_layout.hpp" -#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" #include "element_wise_operation.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" +#include "reference_conv_fwd_bias_activation_add.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -35,70 +36,52 @@ static constexpr auto ConvFwdDefault = // clang-format off using DeviceConvFwdInstance = ck::tensor_operation::device:: - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K -// | InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| -// | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| -// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>; + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvFwdDefault, // ConvForwardSpecialization + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl // clang-format on -template -void host_reference_calculation(const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - const Tensor& bias_k, - const Tensor& resi_n_k_ho_wo, - const std::vector& conv_strides, - const std::vector& conv_dilations, - const std::vector& in_left_pads, - const std::vector& /* in_right_pads */, - const InElementOp& in_element_op, - const WeiElementOp& wei_element_op, - const OutElementOp& out_element_op) -{ - auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { - double v = 0; - for(int c = 0; c < wei_k_c_y_x.mDesc.GetLengths()[1]; ++c) - { - for(int y = 0; y < wei_k_c_y_x.mDesc.GetLengths()[2]; ++y) - { - int hi = ho * conv_strides[0] + y * conv_dilations[0] - in_left_pads[0]; - for(int x = 0; x < wei_k_c_y_x.mDesc.GetLengths()[3]; ++x) - { - int wi = wo * conv_strides[1] + x * conv_dilations[1] - in_left_pads[1]; - if(hi >= 0 && hi < in_n_c_hi_wi.mDesc.GetLengths()[2] && wi >= 0 && - wi < in_n_c_hi_wi.mDesc.GetLengths()[3]) - { - v += in_element_op(static_cast(in_n_c_hi_wi(n, c, hi, wi))) * - wei_element_op(static_cast(wei_k_c_y_x(k, c, y, x))); - } - } - } - } - - double v2 = out_n_k_ho_wo(n, k, ho, wo); - - out_element_op(v2, - v, - static_cast(bias_k(k)), - static_cast(resi_n_k_ho_wo(n, k, ho, wo))); - - out_n_k_ho_wo(n, k, ho, wo) = v2; - }; - - make_ParallelTensorFunctor(f_nchw, - out_n_k_ho_wo.mDesc.GetLengths()[0], - out_n_k_ho_wo.mDesc.GetLengths()[1], - out_n_k_ho_wo.mDesc.GetLengths()[2], - out_n_k_ho_wo.mDesc.GetLengths()[3])( - std::thread::hardware_concurrency()); -} +using ReferenceConvFwdInstance = + ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation_Add; int main(int argc, char* argv[]) { @@ -292,18 +275,22 @@ int main(int argc, char* argv[]) if(do_verification) { - host_reference_calculation(in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo_host_result, - bias_k, - resi_n_k_ho_wo, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); + auto refConv = ReferenceConvFwdInstance{}; + auto refInvoker = refConv.MakeInvoker(); + + auto refArgument = refConv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + bias_k, + resi_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + refInvoker.Run(refArgument); out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 6f231bcdf0..c25e78bf29 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -2,6 +2,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR} ${PROJECT_SOURCE_DIR}/host/host_tensor/include ${PROJECT_SOURCE_DIR}/host/device/include + ${PROJECT_SOURCE_DIR}/host/include ${PROJECT_SOURCE_DIR}/device_operation/include ${PROJECT_SOURCE_DIR}/composable_kernel/include ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility diff --git a/host/include/reference_conv_fwd.hpp b/host/include/reference_conv_fwd.hpp new file mode 100644 index 0000000000..a92ed95b3c --- /dev/null +++ b/host/include/reference_conv_fwd.hpp @@ -0,0 +1,166 @@ +#ifndef REFERENCE_CONV_FWD_HPP +#define REFERENCE_CONV_FWD_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X] +template +struct ReferenceConvFwd : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : in_n_c_hi_wi_{in_n_c_hi_wi}, + wei_k_c_y_x_{wei_k_c_y_x}, + out_n_k_ho_wo_{out_n_k_ho_wo}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& in_n_c_hi_wi_; + const Tensor& wei_k_c_y_x_; + Tensor& out_n_k_ho_wo_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvFwd::Argument; + + float Run(const Argument& arg) + { + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v = 0; + for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) + { + for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - + arg.in_left_pads_[0]; + for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - + arg.in_left_pads_[1]; + if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 && + wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) + { + v += arg.in_element_op_( + ck::type_convert(arg.in_n_c_hi_wi_(n, c, hi, wi))) * + arg.wei_element_op_( + ck::type_convert(arg.wei_k_c_y_x_(k, c, y, x))); + } + } + } + } + + arg.out_n_k_ho_wo_(n, k, ho, wo) = + ck::type_convert(arg.out_element_op_(v)); + }; + + make_ParallelTensorFunctor(f_nchw, + arg.out_n_k_ho_wo_.mDesc.GetLengths()[0], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[1], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[2], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + return 0; + } + + float Run(const device::BaseArgument* p_arg, int) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvFwd" + << std::endl; + // clang-format on + + return str.str(); + } +}; +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/host/include/reference_conv_fwd_bias_activation.hpp b/host/include/reference_conv_fwd_bias_activation.hpp new file mode 100644 index 0000000000..d65bba1a88 --- /dev/null +++ b/host/include/reference_conv_fwd_bias_activation.hpp @@ -0,0 +1,172 @@ +#ifndef REFERENCE_CONV_FWD_BIAS_ACTIVATION_HPP +#define REFERENCE_CONV_FWD_BIAS_ACTIVATION_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) +template +struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : in_n_c_hi_wi_{in_n_c_hi_wi}, + wei_k_c_y_x_{wei_k_c_y_x}, + out_n_k_ho_wo_{out_n_k_ho_wo}, + bias_k_{bias_k}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& in_n_c_hi_wi_; + const Tensor& wei_k_c_y_x_; + Tensor& out_n_k_ho_wo_; + const Tensor& bias_k_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvFwd_Bias_Activation::Argument; + + float Run(const Argument& arg) + { + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v = 0; + for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) + { + for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - + arg.in_left_pads_[0]; + for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - + arg.in_left_pads_[1]; + if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 && + wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) + { + v += arg.in_element_op_( + ck::type_convert(arg.in_n_c_hi_wi_(n, c, hi, wi))) * + arg.wei_element_op_( + ck::type_convert(arg.wei_k_c_y_x_(k, c, y, x))); + } + } + } + } + + arg.out_n_k_ho_wo_(n, k, ho, wo) = + ck::type_convert(arg.out_element_op_(v, arg.bias_k_(k))); + }; + + make_ParallelTensorFunctor(f_nchw, + arg.out_n_k_ho_wo_.mDesc.GetLengths()[0], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[1], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[2], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + return 0; + } + + float Run(const device::BaseArgument* p_arg, int) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo, + bias_k, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvFwd_Bias_Activation" + << std::endl; + // clang-format on + + return str.str(); + } +}; +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/host/include/reference_conv_fwd_bias_activation_add.hpp b/host/include/reference_conv_fwd_bias_activation_add.hpp new file mode 100644 index 0000000000..eb4b708c12 --- /dev/null +++ b/host/include/reference_conv_fwd_bias_activation_add.hpp @@ -0,0 +1,183 @@ +#ifndef REFERENCE_CONV2D_FWD_BIAS_ACTIVATION_ADD_HPP +#define REFERENCE_CONV2D_FWD_BIAS_ACTIVATION_ADD_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K] +template +struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + const Tensor& resi_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : in_n_c_hi_wi_{in_n_c_hi_wi}, + wei_k_c_y_x_{wei_k_c_y_x}, + out_n_k_ho_wo_{out_n_k_ho_wo}, + bias_k_{bias_k}, + resi_n_k_ho_wo_{resi_n_k_ho_wo}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& in_n_c_hi_wi_; + const Tensor& wei_k_c_y_x_; + Tensor& out_n_k_ho_wo_; + const Tensor& bias_k_; + const Tensor& resi_n_k_ho_wo_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvFwd_Bias_Activation_Add::Argument; + + float Run(const Argument& arg) + { + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v = 0; + for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) + { + for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - + arg.in_left_pads_[0]; + for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - + arg.in_left_pads_[1]; + if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 && + wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) + { + v += arg.in_element_op_( + ck::type_convert(arg.in_n_c_hi_wi_(n, c, hi, wi))) * + arg.wei_element_op_( + ck::type_convert(arg.wei_k_c_y_x_(k, c, y, x))); + } + } + } + } + + float v2 = ck::type_convert(arg.out_n_k_ho_wo_(n, k, ho, wo)); + + arg.out_element_op_(v2, + v, + ck::type_convert(arg.bias_k_(k)), + ck::type_convert(arg.resi_n_k_ho_wo_(n, k, ho, wo))); + + arg.out_n_k_ho_wo_(n, k, ho, wo) = ck::type_convert(v2); + }; + + make_ParallelTensorFunctor(f_nchw, + arg.out_n_k_ho_wo_.mDesc.GetLengths()[0], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[1], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[2], + arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + return 0; + } + + float Run(const device::BaseArgument* p_arg, int) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const Tensor& bias_k, + const Tensor& resi_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo, + bias_k, + resi_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvFwd_Bias_Activation_Add" + << std::endl; + // clang-format on + + return str.str(); + } +}; +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif