// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include "ck/ck.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/number.hpp" #include "ck/utility/tuple.hpp" #include "ck/utility/tuple_helper.hpp" #include "ck/utility/dynamic_buffer.hpp" #include "ck/utility/amd_address_space.hpp" #include "ck/utility/multi_index.hpp" // Disable from doxygen docs generation /// @cond INTERNAL namespace ck { namespace wrapper { /// @endcond /** * \brief Memory type, allowed members: * - Generic, * - Global, * - Lds, * - Sgpr, * - Vgpr, */ using MemoryTypeEnum = AddressSpaceEnum; // Disable from doxygen docs generation /// @cond INTERNAL // forward declarations template struct Layout; template struct Tensor; template struct Slice { __host__ __device__ constexpr Slice() : from_(), to_() {} __host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {} /** * \brief Calculate slice range. * * \param dim Dimension size. * \return Slice range. */ template __host__ __device__ constexpr auto range(const T& dim) const { if constexpr(is_same_v || is_same_v || is_same_v, index_t>) { if(to_ < 0) { return dim - from_ + to_ + 1; } else { // workaround if one end of the interval is index_t and the second one is Number return static_cast(to_) - static_cast(from_); } } else { static_assert(T{} >= ToType{} && FromType{} >= Number<0>{} && (ToType{} < 0 || ToType{} > FromType{}), "Invalid range"); if constexpr(ToType{} < 0) { return dim - from_ + to_ + Number<1>{}; } else { return to_ - from_; } } } __host__ __device__ static constexpr bool IsSlice() { return true; } const FromType from_; const ToType to_; }; template using is_slice = decltype(std::declval().IsSlice()); template using is_tuple = decltype(std::declval().IsTuple()); /// @endcond /** * \brief Make tensor function. * * \tparam MemoryType Type of memory. * \param pointer Pointer to the memory. * \param layout Tensor layout. * \return Constructed tensor. */ template constexpr auto make_tensor(ElementType* pointer, const Layout& layout) { return Tensor(pointer, layout); } /** * \brief Make SGPR or VGPR tensor function. * * \tparam MemoryType Type of memory. * \tparam ElementType Memory data type. * \return Constructed tensor. */ template constexpr auto make_register_tensor(const Layout& layout) { return Tensor(layout); } /** * \brief Clear tensor. (Only for Vpgr/Sgpr) * * \param tensor Tensor to be cleared. */ template __host__ __device__ void clear(Tensor& tensor) { static_assert( !Tensor::IsDynamicBuffer); return tensor.GetBuffer().Clear(); } /** * \brief Get Tensor Layout. * * \param tensor Tensor to get layout of. * \return Requsted layout. */ template __host__ __device__ constexpr const auto& layout(const Tensor& tensor) { return tensor.GetLayout(); } /** * \brief Product of tensor shape dims. * * \tparam Idxs Indexes to access specific shape dim (optional). * \param tensor Tensor to get Shape of. * \return Requsted size. */ template __host__ __device__ constexpr auto size(const Tensor& tensor) { return size(tensor.GetLayout()); } /** * \brief Rank of Shape tuple. * * \tparam Idxs Indexes to access specific shape dim (optional). * \param tensor Tensor to get rank of. * \return Requsted rank. */ template __host__ __device__ constexpr auto rank(const Tensor& tensor) { return rank(tensor.GetLayout()); } /** * \brief Depth of Shape tuple. * * \tparam Idxs Indexes to access specific shape dim (optional). * \param tensor Tensor to get depth of. * \return Requsted depth. */ template __host__ __device__ constexpr auto depth(const Tensor& tensor) { return depth(tensor.GetLayout()); } /** * \brief Get Tensor shape. * * \param tensor Tensor to get shape from. * \return Requsted shape. */ template __host__ __device__ constexpr const auto& shape(const Tensor& tensor) { return shape(tensor.GetLayout()); } /** * \brief Get dim slice. * * \param from Beginning of the interval. * \param to End of the interval. (could be also negative to index from the end) * \return Requested slice. Could be used to create sliced tensor from other tensor. */ template constexpr auto slice(const FromType from, const ToType to) { return Slice(from, to); } /** * \brief Get dim slice. (Assumed that from is equal to 1) * * \param to End of the interval. (could be also negative to index from the end) * \return Requested slice. Could be used to create sliced tensor from other tensor. */ template constexpr auto slice(const ToType to) { if constexpr(is_same_v) { return Slice(0, to); } else { return Slice, ToType>(Number<0>{}, to); } } /** * \brief Get whole dim slice (from = 0, to = -1). * * \return Requested slice. Could be used to create sliced tensor from other tensor. */ constexpr auto slice() { return Slice, Number<-1>>(Number<0>{}, Number<-1>{}); } } // namespace wrapper } // namespace ck