// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" #include #include #include #include #include namespace ck_tile { /* similiar to torch.topk() x (Tensor) – the input tensor. k (int) – the k in “top-k” dim (int, optional) – the dimension to sort along largest (bool, optional) – largest or smallest elements sorted (bool, optional) – elements in sorted order or not output: y_values y_indices https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/TopKImpl.h */ template CK_TILE_HOST void reference_topk(const HostTensor& x, HostTensor& y_values, HostTensor& y_indices, index_t k, index_t dim = -1, bool largest = true, bool sorted = true) { // rank must be the same index_t rank = x.get_num_of_dimension(); assert(static_cast(rank) == y_values.get_num_of_dimension()); assert(static_cast(rank) == y_indices.get_num_of_dimension()); assert(dim == -1 || dim < rank); index_t topk_dim = dim == -1 ? (rank - 1) : dim; index_t topk_src_len = x.get_length(topk_dim); auto x_len = x.get_lengths(); assert(k <= topk_src_len); assert(static_cast(k) == y_values.get_length(topk_dim) && static_cast(k) == y_indices.get_length(topk_dim)); index_t n_parallel = x.get_element_size() / topk_src_len; // clang-format off auto f = [&](auto i_element) { std::vector topk_coord = [&](){ std::vector t_(rank, 0); size_t r = i_element; for(index_t i = rank - 1; i >= 0; i--) { if(i == topk_dim) continue; // topk dim should be zero t_[i] = r % x_len[i]; r = r / x_len[i]; } return t_; }(); using elem_t = std::pair; std::vector q = [&](){ std::vector t_(topk_src_len); for(index_t i = 0; i < topk_src_len; i++) { auto c_ = topk_coord; c_[topk_dim] = i; t_[i].first = x(c_); t_[i].second = i; } return t_; }(); // run topk if(largest) { std::nth_element(q.begin(), q.begin() + k - 1, q.end(), [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; }); if(sorted) { std::sort(q.begin(), q.begin() + k - 1, [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; }); } } else { std::nth_element(q.begin(), q.begin() + k - 1, q.end(), [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; }); if(sorted) { std::sort(q.begin(), q.begin() + k - 1, [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; }); } } // write out for(index_t i = 0; i < k; i++) { auto c_ = topk_coord; c_[topk_dim] = i; y_values(c_) = q[i].first; y_indices(c_) = q[i].second; } }; // clang-format on make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency()); } // TODO: if using this method, the return tensor would be dense(no stride) template CK_TILE_HOST auto reference_topk(const HostTensor& x, index_t k, index_t dim = -1, bool largest = true, bool sorted = true) { auto lens = x.get_lengths(); index_t target_dim = (dim == -1) ? (lens.size() - 1) : dim; assert(target_dim < lens.size()); assert(k <= lens[target_dim]); lens[target_dim] = k; HostTensor y_values(lens); HostTensor y_indices(lens); reference_topk(x, y_values, y_indices, k, dim, largest, sorted); return ck_tile::make_tuple(y_values, y_indices); } } // namespace ck_tile