#ifndef DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
#define DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP

// Device-level operator for a multi-block partial reduction.
//
// The input tensor is viewed as a 2-D (M = invariant, K = reduce) problem. The K
// extent is split across `blkGroupSize` thread blocks; the launched kernel is
// handed the workspace pointers (values and, optionally, indices), not the final
// output pointer.
//
// NOTE(review): in this copy of the file most angle-bracket argument lists appear
// to have been stripped (the two bare `#include` lines, the bare `template`
// header, the `DeviceReduce` base, `std::vector`, `std::unique_ptr`, `Number{}`,
// `static_cast(...)`, `dynamic_cast(...)`, `reinterpret_cast(...)`,
// `std::make_unique(...)`, `get_invariant_dims()`, `get_2d_lengths(...)`, and the
// template arguments of the gridwise reduction / kernel). They are left exactly
// as found; restore them from the authoritative header before compiling.

// NOTE(review): header names missing on the next two lines (see note above).
#include
#include
#include "device.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_multiblock_partial_reduce.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// NOTE(review): the template parameter list is missing. The body refers to
// InDataType, AccDataType, OutDataType, Rank, ReduceDims, InElementwiseOperation,
// AccElementwiseOperation, NeedIndices, BlockSize, MThreadClusterSize,
// KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim,
// InSrcVectorSize and OutDstVectorSize; their declared order and any further
// parameters cannot be determined from this file.
template
struct DeviceReduceMultiBlockPartialReduce : public DeviceReduce
{
    // Compile-time sanity checks on the configuration.
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");

    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                  "Invalid thread cluster size assignments!");

    static_assert(OutDstVectorSize == 1,
                  "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");

    // Type used for the index buffers (out_indices_dev_, workspace_indices_dev_).
    using IndexDataType = int32_t;

    // Dimensions that are not reduced; produced by get_invariant_dims()
    // (template arguments missing in this copy).
    using InvariantDims = decltype(get_invariant_dims());

    static constexpr index_t srcDims = Rank;
    // 1 when every dimension is reduced, otherwise the number of invariant dims.
    static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
    static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);

    // Tile extents derived from the thread-cluster and per-thread slice sizes.
    static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    // Returns the workspace size in bytes for an input of the given lengths.
    //
    // The size is invariant_total_length * blkGroupSize elements, where
    // blkGroupSize is ceil(reduce_total_length / (K_BlockTileSize * iterations))
    // for the smallest iterations >= 1 that keeps it at or below 128.
    // Without indices each element takes sizeof(AccDataType) bytes. With indices
    // each element takes sizeof(AccDataType) + sizeof(int) bytes, plus a fixed
    // 64 + sizeof(int) bytes.
    //
    // NOTE(review): the extra 64 bytes presumably cover the 64-byte rounding of
    // the index-buffer offset done in Argument's constructor - confirm.
    // NOTE(review): the search loop below is duplicated verbatim in Argument's
    // constructor; the two must stay in sync.
    size_t GetWorkspaceSizeInBytes(const std::vector& inLengths) override
    {
        size_t invariant_total_length;
        size_t reduce_total_length;

        std::tie(invariant_total_length, reduce_total_length) = get_2d_lengths(inLengths);

        int iterations = 1;
        while(true)
        {
            // NOTE(review): the right-hand side is evaluated in size_t and then
            // narrowed to int.
            int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                                   (K_BlockTileSize * iterations);

            // we want the blkGroupSize be not more than 128
            if(testBlkGroupSize <= 128)
                break;

            iterations++;
        };

        // Ceiling division of the reduce length by the per-block reduce size.
        int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                           (K_BlockTileSize * iterations);

        size_t workspace_size = invariant_total_length * blkGroupSize;

        size_t wsSizeInBytes =
            !NeedIndices
                ? workspace_size * sizeof(AccDataType)
                : workspace_size * (sizeof(AccDataType) + sizeof(int)) + 64 + sizeof(int);

        return (wsSizeInBytes);
    };

    // Always true for this operator.
    // NOTE(review): presumably tells the caller that a further reduction over
    // the workspace is needed to obtain the final output - confirm against the
    // DeviceReduce base interface.
    bool HasFurtherCall() override { return (true); };

    // Builds the padded 2-D (M, K) descriptor of the input tensor.
    //
    // inLengths / inStrides   lengths and strides of the source tensor
    // blkGroupSize            number of blocks sharing the K extent
    // kBlockTileIterations    K tiles processed per block
    //
    // When all dimensions are reduced, the dims are merged into one and then
    // unmerged into (1, total). Otherwise the invariant dims are merged into
    // dimension 0 and the reduce dims into dimension 1. Dimension 0 is then
    // right-padded by integer_least_multiple(outerLen, M_BlockTileSize) - outerLen
    // and dimension 1 up to K_BlockTileSize * kBlockTileIterations * blkGroupSize.
    static auto MakeSrc2dDescriptor(const std::vector& inLengths,
                                    const std::vector& inStrides,
                                    int blkGroupSize,
                                    int kBlockTileIterations)
    {
        // NOTE(review): `Number{}` has lost its template argument in this copy.
        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{});
        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{});

        const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

        // Immediately-invoked lambda so the branch taken at compile time decides
        // the (different) descriptor type bound to in_grid_desc_m_k.
        const auto in_grid_desc_m_k = [&]() {
            if constexpr(reduceAllDims)
            {
                // Merge all source dims into a single dimension ...
                const auto one_dim_inDesc = transform_tensor_descriptor(
                    inDesc,
                    make_tuple(make_merge_transform(tupleSrcLengths)),
                    make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
                    make_tuple(Sequence<0>{}));

                // ... then split it into (1, total) so the result is still 2-D.
                return transform_tensor_descriptor(one_dim_inDesc,
                                                   make_tuple(make_unmerge_transform(make_tuple(
                                                       1, one_dim_inDesc.GetLength(Number<0>{})))),
                                                   make_tuple(Sequence<0>{}),
                                                   make_tuple(Sequence<0, 1>{}));
            }
            else
            {
                const auto toReduceDimLengths =
                    make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
                const auto invariantDimLengths =
                    make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});

                // Invariant dims -> dimension 0, reduce dims -> dimension 1.
                return transform_tensor_descriptor(
                    inDesc,
                    make_tuple(make_merge_transform(invariantDimLengths),
                               make_merge_transform(toReduceDimLengths)),
                    make_tuple(InvariantDims{}, ReduceDims{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));
            }
        }();

        const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
        const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});

        // K elements assigned to one block.
        const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
        const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
        const auto inPad_K = reduceSizePerBlock * blkGroupSize - innerLen;

        auto in_grid_desc_m_k_padded =
            transform_tensor_descriptor(in_grid_desc_m_k,
                                        make_tuple(make_right_pad_transform(outerLen, inPad_M),
                                                   make_right_pad_transform(innerLen, inPad_K)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (in_grid_desc_m_k_padded);
    };

    // Builds the 2-D workspace descriptor: a packed (outerLen, blkGroupSize)
    // layout whose dimension 0 is right-padded by
    // integer_least_multiple(outerLen, M_BlockTileSize) - outerLen and whose
    // dimension 1 is passed through unchanged.
    static auto MakeWorkspace2dDescriptor(int outerLen, int blkGroupSize)
    {
        auto ws_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(outerLen, blkGroupSize));

        const auto wsPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;

        auto ws_desc_m_k_padded =
            transform_tensor_descriptor(ws_desc_m_k,
                                        make_tuple(make_right_pad_transform(outerLen, wsPad),
                                                   make_pass_through_transform(blkGroupSize)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (ws_desc_m_k_padded);
    };

    // Bundles the tensor description, device pointers and the derived launch
    // configuration for one invocation.
    struct Argument : public BaseArgument
    {
        // Stores the inputs and derives:
        //   - invariant/reduce total lengths and their lowest-dimension lengths,
        //   - blkGroupSize / kBlockTileIterations (same search as
        //     GetWorkspaceSizeInBytes),
        //   - gridSize = (M rounded to M_BlockTileSize) / M_BlockTileSize * blkGroupSize,
        //   - workspace_indices_dev_ (only when NeedIndices).
        // outLengths/outStrides, out_dev, out_indices_dev, alpha and beta are
        // stored but are not read by Invoker::Run in this file.
        Argument(const std::vector& inLengths,
                 const std::vector& inStrides,
                 const std::vector& outLengths,
                 const std::vector& outStrides,
                 float alpha,
                 float beta,
                 const InDataType* in_dev,
                 OutDataType* out_dev,
                 IndexDataType* out_indices_dev,
                 AccDataType* workspace_dev,
                 const InElementwiseOperation& in_elementwise_op,
                 const AccElementwiseOperation& acc_elementwise_op)
            : in_dev_{in_dev},
              out_dev_{out_dev},
              out_indices_dev_{out_indices_dev},
              workspace_dev_{workspace_dev}
        {
            inLengths_  = inLengths;
            inStrides_  = inStrides;
            outLengths_ = outLengths;
            outStrides_ = outStrides;

            in_elementwise_op_  = in_elementwise_op;
            acc_elementwise_op_ = acc_elementwise_op;

            // NOTE(review): target types of these casts are missing in this copy.
            alpha_ = static_cast(alpha);
            beta_  = static_cast(beta);

            std::tie(invariant_total_length, reduce_total_length) = get_2d_lengths(inLengths);

            // Length of the last invariant dimension (1 when everything is reduced).
            if constexpr(InvariantDims::Size() == 0)
                invariant_lowest_length = 1;
            else
                invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];

            // Length of the last reduce dimension.
            reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];

            // Smallest iteration count for which the K extent fits in <= 128 blocks.
            int iterations = 1;
            while(true)
            {
                int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                                       (K_BlockTileSize * iterations);

                // we want the blkGroupSize be not more than 128
                if(testBlkGroupSize <= 128)
                    break;

                iterations++;
            };

            blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                           (K_BlockTileSize * iterations);

            kBlockTileIterations = iterations;

            // One block per (M tile, K group) pair.
            gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                       M_BlockTileSize * blkGroupSize;

            // Byte offset of the index region: size of the value region passed
            // through integer_least_multiple(..., 64).
            size_t ws_buf2_bytes_offset = math::integer_least_multiple(
                invariant_total_length * blkGroupSize * sizeof(AccDataType), 64);

            // NOTE(review): both reinterpret_cast target types are missing in
            // this copy; the offset is added in bytes, so the inner cast is
            // presumably to a byte pointer - confirm.
            if constexpr(NeedIndices)
                workspace_indices_dev_ = reinterpret_cast(
                    reinterpret_cast(workspace_dev_) + ws_buf2_bytes_offset);
            else
                workspace_indices_dev_ = nullptr;
        }

        // Tensor shapes (element type of the vectors is missing in this copy).
        std::vector inLengths_;
        std::vector inStrides_;
        std::vector outLengths_;
        std::vector outStrides_;

        AccDataType alpha_;
        OutDataType beta_;

        // Device pointers supplied by the caller.
        const InDataType* in_dev_;
        OutDataType* out_dev_;
        IndexDataType* out_indices_dev_;
        AccDataType* workspace_dev_;
        // Points into the workspace after the value region; nullptr without indices.
        IndexDataType* workspace_indices_dev_;

        InElementwiseOperation in_elementwise_op_;
        AccElementwiseOperation acc_elementwise_op_;

        // Derived problem sizes.
        int invariant_lowest_length;
        int reduce_lowest_length;
        size_t invariant_total_length;
        size_t reduce_total_length;

        // Derived launch configuration.
        index_t blkGroupSize;
        index_t kBlockTileIterations;
        size_t gridSize;
    };

    // Launches the partial-reduce kernel for a prepared Argument.
    struct Invoker : public BaseInvoker
    {
        // Builds the input and workspace descriptors, selects the kernel and
        // launches it through launch_and_time_kernel with grid dim arg.gridSize
        // and block dim BlockSize. Returns whatever launch_and_time_kernel
        // returns (stored in avg_time).
        float Run(const Argument& arg, int nrepeat = 1)
        {
            const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor(
                arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
            // NOTE(review): invariant_total_length is size_t but the parameter
            // is int, so the value is narrowed here.
            const auto ws_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeWorkspace2dDescriptor(
                arg.invariant_total_length, arg.blkGroupSize);
            using InGridDesc_M_K    = decltype(in_grid_desc_m_k);
            using WorkspaceDesc_M_K = decltype(ws_desc_m_k);

            // NOTE(review): template arguments of the gridwise reduction and of
            // the kernel are missing in this copy.
            using GridwiseReduce = GridwiseReduction_mk_to_mk_multiblock_partial_reduce;

            float avg_time = 0;

            const auto kernel = kernel_partial_reduce_multiblock;

            // The literal 0 follows the block dimension.
            // NOTE(review): presumably the dynamic shared-memory size - confirm
            // against launch_and_time_kernel.
            avg_time = launch_and_time_kernel(kernel,
                                              nrepeat,
                                              dim3(arg.gridSize),
                                              dim3(BlockSize),
                                              0,
                                              in_grid_desc_m_k,
                                              ws_desc_m_k,
                                              arg.in_elementwise_op_,
                                              arg.acc_elementwise_op_,
                                              arg.blkGroupSize,
                                              arg.kBlockTileIterations,
                                              arg.in_dev_,
                                              arg.workspace_dev_,
                                              arg.workspace_indices_dev_);

            return (avg_time);
        };

        // Type-erased entry point: casts to Argument and forwards.
        // NOTE(review): the dynamic_cast result is dereferenced without a null
        // check; passing a different BaseArgument subtype is undefined behaviour.
        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
        {
            return Run(*dynamic_cast(p_arg), nrepeat);
        };
    };

    // Returns true when this instance can handle the given argument:
    //   - the vectorised input dimension (InSrcVectorDim == 0: last invariant
    //     dim, otherwise: last reduce dim) must have stride 1 and a length
    //     divisible by InSrcVectorSize;
    //   - InSrcVectorDim == 0 is rejected when there are no invariant dims;
    //   - reduce_total_length must exceed BlockSize * KThreadSliceSize.
    // NOTE(review): pArg is not checked for nullptr after the dynamic_cast.
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast(p_arg);

        // Cannot trigger: the static_assert at the top of the struct already
        // requires OutDstVectorSize == 1.
        if constexpr(OutDstVectorSize != 1)
            return (false);

        if constexpr(InSrcVectorDim == 0)
        {
            if constexpr(InvariantDims::Size() == 0)
                return (false);

            if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
                return (false);

            if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
                return (false);
        }
        else
        {
            if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
                return (false);

            if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
                return (false);
        };

        // cases with small reduce_total_length should be handled by the BlockWise method
        if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize)
            return (false);

        return (true);
    };

    // Returns the logical workspace shape {invariant_total_length, blkGroupSize}.
    // NOTE(review): pArg is not checked for nullptr after the dynamic_cast.
    std::vector GetWorkspace2dLengths(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast(p_arg);

        return (
            std::vector{static_cast(pArg->invariant_total_length), pArg->blkGroupSize});
    };

    // Type-erased factory: casts the untyped device pointers and forwards
    // everything to std::make_unique.
    // NOTE(review): the constructed type and the cast target types are missing
    // in this copy; presumably Argument and the pointer types of its
    // constructor parameters - confirm.
    std::unique_ptr MakeArgumentPointer(const std::vector& inLengths,
                                        const std::vector& inStrides,
                                        const std::vector& outLengths,
                                        const std::vector& outStrides,
                                        float alpha,
                                        float beta,
                                        const void* in_dev,
                                        void* out_dev,
                                        void* out_indices_dev,
                                        void* workspace_dev,
                                        const InElementwiseOperation& in_elementwise_op,
                                        const AccElementwiseOperation& acc_elementwise_op) override
    {
        return std::make_unique(inLengths,
                                inStrides,
                                outLengths,
                                outStrides,
                                alpha,
                                beta,
                                static_cast(in_dev),
                                static_cast(out_dev),
                                static_cast(out_indices_dev),
                                static_cast(workspace_dev),
                                in_elementwise_op,
                                acc_elementwise_op);
    };

    // Type-erased factory for the invoker.
    // NOTE(review): the constructed type is missing in this copy; presumably
    // Invoker - confirm.
    std::unique_ptr MakeInvokerPointer() override { return std::make_unique(); };

    // Returns a description of this instantiation of the form
    //   DeviceReduceMultiBlockPartialReduce<BlockSize,M_C<mc>_S<ms>,K_C<kc>_S<ks>,
    //   InSrcVectorDim_<d>_InSrcVectorSize_<n>_OutDstVectorSize_<m>>
    // (a single string, no line break) with the numeric template values
    // substituted.
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceReduceMultiBlockPartialReduce<" << BlockSize << ",";
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif