#ifndef DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP #define DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP #include #include #include "device.hpp" #include "device_reduce.hpp" #include "device_reduce_common.hpp" #include "gridwise_2d_reduction_blockwise.hpp" namespace ck { namespace tensor_operation { namespace device { template struct DeviceReduceBlockWiseSecondCall : public DeviceReduce { static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, "Invalid thread cluster size assignments!"); using IndexDataType = int32_t; static constexpr bool BetaIsZero = NeedIndices; static_assert( std::is_same::value, "InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!"); using InvariantDims = decltype(get_invariant_dims()); static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size(); static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static auto MakeSrc2dDescriptor(const std::vector& inLengths, const std::vector& inStrides) { const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<2>{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<2>{}); const auto in_grid_desc_m_k = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{}); const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{}); const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen; const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen; auto in_grid_desc_m_k_padded = transform_tensor_descriptor(in_grid_desc_m_k, make_tuple(make_right_pad_transform(outerLen, inPad_M), make_right_pad_transform(innerLen, inPad_K)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); return (in_grid_desc_m_k_padded); }; static auto MakeDst1dDescriptor(const std::vector& outLengths, const std::vector& outStrides) { const auto tupleDstLengths = make_tuple_from_array(outLengths, Number{}); const auto tupleDstStrides = make_tuple_from_array(outStrides, Number{}); auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto out_grid_desc_m = transform_tensor_descriptor( outDesc, make_tuple(make_merge_transform(tupleDstLengths)), make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), make_tuple(Sequence<0>{})); const auto outerLen = out_grid_desc_m.GetLength(Number<0>{}); const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen; auto out_grid_desc_m_padded = transform_tensor_descriptor(out_grid_desc_m, make_tuple(make_right_pad_transform(outerLen, outPad)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{})); return (out_grid_desc_m_padded); }; struct Argument : public BaseArgument { Argument(const std::vector& inLengths, const std::vector& inStrides, const std::vector& outLengths, const std::vector& outStrides, float alpha, float beta, const InDataType* in_dev, OutDataType* out_dev, IndexDataType* out_indices_dev, AccDataType* workspace_dev, const InElementwiseOperation& in_elementwise_op, const AccElementwiseOperation& acc_elementwise_op) : in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev} { inLengths_ = inLengths; inStrides_ = inStrides; outLengths_ = outLengths; outStrides_ = outStrides; in_elementwise_op_ = in_elementwise_op; acc_elementwise_op_ = acc_elementwise_op; alpha_ = static_cast(alpha); beta_ = static_cast(beta); invariant_total_length = inLengths[0]; reduce_total_length = inLengths[1]; invariant_lowest_length = inLengths[0]; reduce_lowest_length = inLengths[1]; gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / M_BlockTileSize; size_t ws_buf2_bytes_offset = math::integer_least_multiple( invariant_total_length * reduce_total_length * sizeof(AccDataType), 64); if constexpr(NeedIndices) workspace_indices_dev_ = reinterpret_cast( reinterpret_cast(workspace_dev) + ws_buf2_bytes_offset); else workspace_indices_dev_ = nullptr; } std::vector inLengths_; std::vector inStrides_; std::vector outLengths_; std::vector outStrides_; AccDataType alpha_; OutDataType beta_; const InDataType* in_dev_; OutDataType* out_dev_; IndexDataType* out_indices_dev_; IndexDataType* workspace_indices_dev_; InElementwiseOperation in_elementwise_op_; AccElementwiseOperation acc_elementwise_op_; int invariant_lowest_length; int reduce_lowest_length; size_t invariant_total_length; size_t reduce_total_length; size_t gridSize; }; struct Invoker : public BaseInvoker { float Run(const Argument& arg, int nrepeat = 1) { const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor( arg.inLengths_, arg.inStrides_); const auto out_grid_desc_m = DeviceReduceBlockWiseSecondCall::MakeDst1dDescriptor( arg.outLengths_, arg.outStrides_); using InGridDesc_M_K = decltype(in_grid_desc_m_k); using OutGridDesc_M = decltype(out_grid_desc_m); using GridwiseReduce = GridwiseReduction_mk_to_m_blockwise; float avg_time = 0; const auto kernel = kernel_reduce_blockwise_second_call; avg_time = launch_and_time_kernel(kernel, nrepeat, dim3(arg.gridSize), dim3(BlockSize), 0, in_grid_desc_m_k, out_grid_desc_m, arg.in_elementwise_op_, arg.acc_elementwise_op_, arg.alpha_, arg.in_dev_, arg.beta_, arg.out_dev_, arg.workspace_indices_dev_, arg.out_indices_dev_); return (avg_time); }; float Run(const BaseArgument* p_arg, int nrepeat = 1) override { return Run(*dynamic_cast(p_arg), nrepeat); }; }; bool IsSupportedArgument(const BaseArgument* p_arg) override { const Argument* pArg = dynamic_cast(p_arg); if constexpr(InSrcVectorDim == 0) return (false); if(pArg->reduce_lowest_length % InSrcVectorSize != 0) return (false); // To improve if(pArg->invariant_lowest_length % OutDstVectorSize != 0) return (false); // cases with very small reduce_total_length should be handled by the ThreadWise method if(pArg->reduce_total_length / KThreadSliceSize < 2) return (false); return (true); }; std::unique_ptr MakeArgumentPointer(const std::vector& inLengths, const std::vector& inStrides, const std::vector& outLengths, const std::vector& outStrides, float alpha, float beta, const void* in_dev, void* out_dev, void* out_indices_dev, void* workspace_dev, const InElementwiseOperation& in_elementwise_op, const AccElementwiseOperation& acc_elementwise_op) override { return std::make_unique(inLengths, inStrides, outLengths, outStrides, alpha, beta, static_cast(in_dev), static_cast(out_dev), static_cast(out_indices_dev), static_cast(workspace_dev), in_elementwise_op, acc_elementwise_op); }; std::unique_ptr MakeInvokerPointer() override { return std::make_unique(); }; std::string GetTypeString() const override { auto str = std::stringstream(); // clang-format off str << "DeviceReduceBlockWiseSecondCall<" << BlockSize << ","; str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; // clang-format on return str.str(); } }; } // namespace device } // namespace tensor_operation } // namespace ck #endif