// SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" #include namespace ck_tile { template CK_TILE_HOST void reference_reduce(const HostTensor& x_m_n, HostTensor& y_m, ReduceOp reduce_op) { auto f = [&](auto m) { const int N = x_m_n.mDesc.get_lengths()[1]; ComputeDataType v_acc = reduce_op.template GetIdentityValue(); for(int n = 0; n < N; ++n) { const ComputeDataType v_a = type_convert(x_m_n(m, n)); v_acc = reduce_op(v_acc, v_a); } y_m(m) = ck_tile::type_convert(v_acc); }; make_ParallelTensorFunctor(f, y_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); } // Generic reference reduce for arbitrary dimensions template CK_TILE_HOST void reference_reduce(const HostTensor& x_tensor, HostTensor& y_tensor, ReduceOp reduce_op, KeptDim kept_dim, ReduceDims reduce_dims) { const auto& x_lengths = x_tensor.mDesc.get_lengths(); const auto& x_strides = x_tensor.mDesc.get_strides(); const auto& y_strides = y_tensor.mDesc.get_strides(); // Calculate total kept elements (product of all kept dimension lengths) index_t total_kept_elements = 1; static_for<0, kept_dim.size(), 1>{}( [&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; }); // Calculate total reduce elements (product of all reduce dimension lengths) index_t total_reduce_elements = 1; static_for<0, reduce_dims.size(), 1>{}( [&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; }); auto f = [&](auto linear_kept_idx) { ComputeDataType v_acc = reduce_op.template GetIdentityValue(); // Convert linear kept index to multi-dimensional kept indices std::vector kept_indices(kept_dim.size()); index_t temp_kept = linear_kept_idx; static_for<0, kept_dim.size(), 1>{}([&](auto i) { constexpr auto dim_idx = kept_dim.size() - 1 - i; constexpr auto dim = kept_dim.at(dim_idx); const auto len = x_lengths[dim]; kept_indices[dim_idx] = temp_kept % len; temp_kept /= len; }); for(index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx) { // Convert linear reduce index to multi-dimensional reduce indices std::vector reduce_indices(reduce_dims.size()); index_t temp_reduce = reduce_idx; static_for<0, reduce_dims.size(), 1>{}([&](auto i) { constexpr auto dim_idx = reduce_dims.size() - 1 - i; constexpr auto dim = reduce_dims.at(dim_idx); const auto len = x_lengths[dim]; reduce_indices[dim_idx] = temp_reduce % len; temp_reduce /= len; }); // Build full input tensor indices by combining kept and reduce indices std::vector full_indices(x_lengths.size(), 0); static_for<0, kept_dim.size(), 1>{}( [&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; }); static_for<0, reduce_dims.size(), 1>{}( [&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; }); // Calculate flat input tensor index index_t flat_x_idx = 0; for(size_t d = 0; d < full_indices.size(); ++d) { flat_x_idx += full_indices[d] * x_strides[d]; } const auto v_a = type_convert(x_tensor.mData[flat_x_idx]); v_acc = reduce_op(v_acc, v_a); } // Calculate output tensor index using kept indices and output strides // The output tensor has the same structure as the kept dimensions index_t flat_y_idx = 0; static_for<0, kept_dim.size(), 1>{}( [&](auto i) { flat_y_idx += kept_indices[i] * y_strides[i]; }); y_tensor.mData[flat_y_idx] = type_convert(v_acc); }; make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency()); } } // namespace ck_tile