// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" #include #include #include namespace ck_tile { /* this will do permute + contiguous like functionality in pytorch */ template CK_TILE_HOST void reference_permute(const HostTensor& x, HostTensor& y, std::vector perm) { const auto x_len = x.mDesc.get_lengths(); const auto y_len = y.mDesc.get_lengths(); assert(x_len.size() == y_len.size()); index_t rank = x_len.size(); const auto x_elm = std::accumulate(x_len.begin(), x_len.end(), 1, std::multiplies()); const auto y_elm = std::accumulate(y_len.begin(), y_len.end(), 1, std::multiplies()); assert(x_elm == y_elm); (void)y_elm; auto f = [&](auto i_element) { std::vector y_coord = [&]() { std::vector tmp(rank, 0); size_t r = i_element; for(index_t i = rank - 1; i >= 0; i--) { tmp[i] = r % y_len[i]; r = r / y_len[i]; } return tmp; }(); std::vector x_coord = [&]() { std::vector tmp(rank, 0); for(index_t i = 0; i < rank; i++) { tmp[perm[i]] = y_coord[i]; } return tmp; }(); // do permute y(y_coord) = x(x_coord); }; make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency()); } template CK_TILE_HOST auto reference_permute(const HostTensor& x, std::vector perm) { auto x_shape = x.get_lengths(); ck_tile::index_t rank = perm.size(); std::vector y_shape = [&]() { std::vector tmp(rank, 0); for(int i = 0; i < static_cast(rank); i++) { tmp[i] = x_shape[perm[i]]; } return tmp; }(); HostTensor y(y_shape); reference_permute(x, y, perm); return y; } } // namespace ck_tile