diff --git a/device_operation/include/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/device_operation/include/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..46dd86e591 --- /dev/null +++ b/device_operation/include/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,710 @@ +#ifndef DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONV2D_WRW_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_backward_weight.hpp" +#include "convolution_forward_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r4r2.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvWrw +{ + using DeviceOp = DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X * Y; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1)); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum_t::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; + + using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum_t::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + output_spatial_lengths_{output_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + M01_, + N01_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation a_element_op_; + OutElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector output_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, int nrepeat = 1) + { + ShowInfo(arg); + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch); + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(nrepeat > 0) + { + ave_time = + launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + if(kbatch > 1 || nrepeat <= 0) + { + hipGetErrorString(hipMemset( + arg.p_c_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(CDataType))); + + launch_kernel(kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4r2< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_conv_backward_weight.hpp b/device_operation/include/device_conv_backward_weight.hpp new file mode 100644 index 0000000000..c025fa61a5 --- /dev/null +++ b/device_operation/include/device_conv_backward_weight.hpp @@ -0,0 +1,47 @@ +#ifndef DEVICE_CONV_WRW_HPP +#define DEVICE_CONV_WRW_HPP + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvWrw : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + void* p_wei, + const void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvWrwPtr = std::unique_ptr< + DeviceConvWrw>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/example/13_conv2d_backward_weight_xdl/README.md b/example/13_conv2d_backward_weight_xdl/README.md new file mode 100644 index 0000000000..16e9bbc455 --- /dev/null +++ b/example/13_conv2d_backward_weight_xdl/README.md @@ -0,0 +1,58 @@ +# Instructions for ```conv2d_wrw_xdl``` Example + +## Docker script +```bash +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ +rocm/tensorflow:rocm4.3.1-tf2.6-dev \ +/bin/bash +``` + +## Build ```conv2d_wrw_xdl``` +```bash +mkdir build && cd build +``` + +```bash +# Need to specify target ID, example below is gfx908 +cmake \ +-D BUILD_DEV=OFF \ +-D CMAKE_BUILD_TYPE=Release \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +.. +``` + +```bash + make -j conv2d_wrw_xdl +``` + +## Run ```conv2d_wrw_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4: is show log (0=no, 1=yes) +#arg5 to 19: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx, split-k +./example/conv2d_fwd_xdl 0 1 5 0 4 +``` + +Result +``` +in_n_c_hi_wi: dim 4, lengths {128, 1024, 14, 14}, strides {200704, 1, 14336, 1024} +wei_k_c_y_x: dim 4, lengths {256, 1024, 3, 3}, strides {9216, 1, 3072, 1024} +out_n_k_ho_wo: dim 4, lengths {128, 256, 6, 6}, strides {9216, 1, 1536, 256} +arg.a_grid_desc_kbatch_k0_m_k1_{4, 144, 256, 8} +arg.b_grid_desc_kbatch_k0_n_k1_{4, 144, 9216, 8} +arg.c_grid_desc_m_n_{ 256, 9216} +launch_and_time_kernel: grid_dim {576, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 0.401084 ms, 54.2112 TFlops, 145.75 GB/s +``` diff --git a/example/13_conv2d_backward_weight_xdl/main.cpp b/example/13_conv2d_backward_weight_xdl/main.cpp new file mode 100644 index 0000000000..4141587583 --- /dev/null +++ b/example/13_conv2d_backward_weight_xdl/main.cpp @@ -0,0 +1,289 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "tensor_layout.hpp" +#include "element_wise_operation.hpp" +#include "device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "reference_conv_backward_weight.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using WeiLayout = ck::tensor_layout::convolution::KYXC; +using OutLayout = ck::tensor_layout::convolution::NHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +using DeviceConvWrWInstance = ck::tensor_operation::device:: + DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +using ReferenceConvWrwInstance = ck::tensor_operation::host:: + ReferenceConvWrw; + +int main(int argc, char* argv[]) +{ + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + int do_log = 0; + int split_k = 4; + + // Conv shape + ck::index_t N = 128; + ck::index_t K = 256; + ck::index_t C = 1024; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 14; + ck::index_t Wi = 14; + ck::index_t conv_stride_h = 2; + ck::index_t conv_stride_w = 2; + ck::index_t conv_dilation_h = 1; + ck::index_t conv_dilation_w = 1; + ck::index_t in_left_pad_h = 0; + ck::index_t in_left_pad_w = 0; + ck::index_t in_right_pad_h = 0; + ck::index_t in_right_pad_w = 0; + + if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + } + else if(argc == 21) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + + N = std::stoi(argv[6]); + K = std::stoi(argv[7]); + C = std::stoi(argv[8]); + Y = std::stoi(argv[9]); + X = std::stoi(argv[10]); + Hi = std::stoi(argv[11]); + Wi = std::stoi(argv[12]); + conv_stride_h = std::stoi(argv[13]); + conv_stride_w = std::stoi(argv[14]); + conv_dilation_h = std::stoi(argv[15]); + conv_dilation_w = std::stoi(argv[16]); + in_left_pad_h = std::stoi(argv[17]); + in_left_pad_w = std::stoi(argv[18]); + in_right_pad_h = std::stoi(argv[19]); + in_right_pad_w = std::stoi(argv[20]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4: is show log (0=no, 1=yes)\n"); + printf("arg5: split-k \n"); + printf("arg6 to 19: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(0); + } + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + const std::vector conv_filter_strides{{conv_stride_h, conv_stride_w}}; + const std::vector conv_filter_dilations{{conv_dilation_h, conv_dilation_w}}; + const std::vector input_left_pads{{in_left_pad_h, in_left_pad_w}}; + const std::vector input_right_pads{{in_right_pad_h, in_right_pad_w}}; + + // tensor layout + auto f_host_tensor_descriptor = [](std::size_t N_, + std::size_t C_, + std::size_t H, + std::size_t W, + auto layout) { + if constexpr(ck::is_same::value || + ck::is_same::value || + ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(ck::is_same::value || + ck::is_same::value || + ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x_host_result(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor wei_k_c_y_x_device_result( + f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + } + wei_k_c_y_x_device_result.GenerateTensorValue(GeneratorTensor_1{0}); + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * + wei_k_c_y_x_device_result.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x_device_result.mData.data()); + + // do GEMM + auto conv = DeviceConvWrWInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + std::vector{{Hi, Wi}}, + std::vector{{Y, X}}, + std::vector{{Ho, Wo}}, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}, + split_k); + + if(!conv.IsSupportedArgument(argument)) + { + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return 1; + } + + float ave_time = invoker.Run(argument, nrepeat); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto ref_conv = ReferenceConvWrwInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x_host_result, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",") + << std::endl; + } + check_error(wei_k_c_y_x_host_result, wei_k_c_y_x_device_result); + } +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index d26da43f57..1f7b7ad7bd 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -24,6 +24,7 @@ set(CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE 6_conv2d_fwd_xdl_bias_relu_add/conv2d_fw set(CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE 7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp) set(GEMM_XDL_ALPHA_BETA_SOURCE 8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp) set(CONV2D_FWD_XDL_INT8_SOURCE 9_conv2d_fwd_xdl_int8/conv2d_fwd_xdl_int8.cpp) +set(CONV2D_WRW_XDL_SOURCE 13_conv2d_backward_weight_xdl/main.cpp) set(CONV3D_FWD_XDL_SOURCE 10_conv3d_fwd_xdl/conv3d_fwd_xdl.cpp) set(CONVND_FWD_XDL_SOURCE 11_convnd_fwd_xdl/convnd_fwd_xdl.cpp) set(CONV2D_BWD_DATA_XDL_SOURCE 12_conv2d_bwd_data_xdl/conv2d_bwd_data_xdl.cpp) @@ -39,6 +40,7 @@ add_executable(conv2d_fwd_xdl_bias_relu_add ${CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURC add_executable(conv2d_fwd_xdl_bias_relu_atomic_add ${CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE}) add_executable(gemm_xdl_alpha_beta ${GEMM_XDL_ALPHA_BETA_SOURCE}) add_executable(conv2d_fwd_xdl_int8 ${CONV2D_FWD_XDL_INT8_SOURCE}) +add_executable(conv2d_wrw_xdl ${CONV2D_WRW_XDL_SOURCE}) add_executable(conv3d_fwd_xdl ${CONV3D_FWD_XDL_SOURCE}) add_executable(convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE}) add_executable(conv2d_bwd_data_xdl ${CONV2D_BWD_DATA_XDL_SOURCE}) @@ -54,6 +56,7 @@ target_link_libraries(conv2d_fwd_xdl_bias_relu_add PRIVATE host_tensor) target_link_libraries(conv2d_fwd_xdl_bias_relu_atomic_add PRIVATE host_tensor) target_link_libraries(gemm_xdl_alpha_beta PRIVATE host_tensor) target_link_libraries(conv2d_fwd_xdl_int8 PRIVATE host_tensor) +target_link_libraries(conv2d_wrw_xdl PRIVATE host_tensor) target_link_libraries(conv3d_fwd_xdl PRIVATE host_tensor) target_link_libraries(convnd_fwd_xdl PRIVATE host_tensor) target_link_libraries(conv2d_bwd_data_xdl PRIVATE host_tensor) diff --git a/reference_operation/include/reference_conv_backward_weight.hpp b/reference_operation/include/reference_conv_backward_weight.hpp new file mode 100644 index 0000000000..d36a29b3a0 --- /dev/null +++ b/reference_operation/include/reference_conv_backward_weight.hpp @@ -0,0 +1,177 @@ +#ifndef REFERENCE_CONV_WRW_HPP +#define REFERENCE_CONV_WRW_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X] +template +struct ReferenceConvWrw : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : in_n_c_hi_wi_{in_n_c_hi_wi}, + wei_k_c_y_x_{wei_k_c_y_x}, + out_n_k_ho_wo_{out_n_k_ho_wo}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + } + + const Tensor& in_n_c_hi_wi_; + Tensor& wei_k_c_y_x_; + const Tensor& out_n_k_ho_wo_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceConvWrw::Argument; + + float Run(const Argument& arg) + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + auto f_kcyx = [&](auto k, auto c, auto y, auto x) { + float v_acc = 0; + for(int n = 0; n < arg.out_n_k_ho_wo_.mDesc.GetLengths()[0]; ++n) + { + for(int ho = 0; ho < arg.out_n_k_ho_wo_.mDesc.GetLengths()[2]; ++ho) + { + int hi = ho * arg.conv_strides_[I0] + y * arg.conv_dilations_[I0] - + arg.in_left_pads_[I0]; + for(int wo = 0; wo < arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]; ++wo) + { + int wi = wo * arg.conv_strides_[I1] + x * arg.conv_dilations_[I1] - + arg.in_left_pads_[I1]; + if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 && + wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) + { + float v_out; + float v_in; + + arg.out_element_op_( + v_out, + ck::type_convert(arg.out_n_k_ho_wo_(n, k, ho, wo))); + arg.in_element_op_( + v_in, ck::type_convert(arg.in_n_c_hi_wi_(n, c, hi, wi))); + + v_acc += v_out * v_in; + } + } + } + } + float v_wei; + + arg.wei_element_op_(v_wei, v_acc); + + arg.wei_k_c_y_x_(k, c, y, x) = ck::type_convert(v_wei); + }; + + make_ParallelTensorFunctor(f_kcyx, + arg.wei_k_c_y_x_.mDesc.GetLengths()[0], + arg.wei_k_c_y_x_.mDesc.GetLengths()[1], + arg.wei_k_c_y_x_.mDesc.GetLengths()[2], + arg.wei_k_c_y_x_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, int) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceConvFwd" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif