// SPDX-License-Identifier: MIT // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/ck.hpp" #include "ck/utility/number.hpp" #include "ck/utility/tuple.hpp" #include "ck/utility/tuple_helper.hpp" #include "ck/utility/dynamic_buffer.hpp" #include "ck/utility/amd_address_space.hpp" namespace ck { namespace wrapper { /** * \brief Memory type, allowed members: * - Generic, * - Global, * - LDS, * - SGPR, * - VGPR, */ using MemoryTypeEnum = AddressSpaceEnum; // Disable from doxygen docs generation /// @cond // forward declarations template struct Layout; template struct Tensor; template struct Slice { __host__ __device__ constexpr Slice() : from_(), to_() {} __host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {} template __host__ __device__ constexpr auto range(const T& dim) const { if constexpr(is_same_v || is_same_v || is_same_v) { assert(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_) && "Invalid range"); if(to_ < 0) { return dim - from_ + to_ + 1; } else { // workaround if one end of the interval is index_t and the second one is Number return static_cast(to_) - static_cast(from_); } } else { static_assert(dim >= to_ && from_ >= Number<0>{} && (to_ < 0 || to_ > from_), "Invalid range"); if constexpr(to_ < 0) { return dim - from_ + to_ + Number<1>{}; } else { return to_ - from_; } } } __host__ __device__ static constexpr bool IsSlice() { return true; } const FromType from_; const ToType to_; }; template using is_slice = decltype(std::declval().IsSlice()); template using is_tuple = decltype(std::declval().IsTuple()); /// @endcond /** * \brief Make tensor function. * * \tparam MemoryType Type of memory. * \param pointer Pointer to the memory. * \param layout Tensor layout. * \return Constructed tensor. */ template constexpr auto make_tensor(ElementType* pointer, const Layout& layout) { return Tensor(pointer, layout); } /** * \brief Make SGPR or VGPR tensor function. * * \tparam MemoryType Type of memory. * \tparam NumVectors Number of vectors. * \tparam ScalarPerVector Scalars per vector. * \tparam ElementType Memory data type. * \return Constructed tensor. */ template constexpr auto make_register_tensor() { const auto layout = make_layout(make_tuple(Number{}), make_tuple(Number<1>{})); return Tensor>, std::remove_const_t>, NumVectors, ScalarPerVector>(layout); } /** * \brief Get Tensor Layout. * * \param tensor Tensor to get layout of. * \return Requsted layout. */ template __host__ __device__ constexpr const auto& layout(const Tensor& tensor) { return tensor.GetLayout(); } /** * \brief Product of tensor shape dims. * * \tparam Idxs Indexes to access specific shape dim (optional). * \param tensor Tensor to get Shape of. * \return Requsted size. */ template __host__ __device__ constexpr auto size(const Tensor& tensor) { return size(tensor.GetLayout()); } /** * \brief Rank of Shape tuple. * * \tparam Idxs Indexes to access specific shape dim (optional). * \param tensor Tensor to get rank of. * \return Requsted rank. */ template __host__ __device__ constexpr auto rank(const Tensor& tensor) { return rank(tensor.GetLayout()); } /** * \brief Depth of Shape tuple. * * \tparam Idxs Indexes to access specific shape dim (optional). * \param tensor Tensor to get depth of. * \return Requsted depth. */ template __host__ __device__ constexpr auto depth(const Tensor& tensor) { return depth(tensor.GetLayout()); } /** * \brief Get Tensor shape. * * \param tensor Tensor to get shape from. * \return Requsted shape. */ template __host__ __device__ constexpr const auto& shape(const Tensor& tensor) { return shape(tensor.GetLayout()); } /** * \brief Get dim slice. * * \param from Beginning of the interval. * \param to End of the interval. (could be also negative to index from the end) * \return Requested slice. Could be used to create sliced tensor from other tensor. */ template constexpr auto slice(const FromType from, const ToType to) { return Slice(from, to); } /** * \brief Get dim slice. (Assumed that from is equal to 1) * * \param to End of the interval. (could be also negative to index from the end) * \return Requested slice. Could be used to create sliced tensor from other tensor. */ template constexpr auto slice(const ToType to) { if constexpr(is_same_v) { return Slice(0, to); } else { return Slice, ToType>(Number<0>{}, to); } } /** * \brief Get whole dim slice (from = 0, to = -1). * * \return Requested slice. Could be used to create sliced tensor from other tensor. */ constexpr auto slice() { return Slice, Number<-1>>(Number<0>{}, Number<-1>{}); } } // namespace wrapper } // namespace ck