mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
Add tensor structure to wrapper (#1098)
* Add tensor structure to wrapper
* update changelog
* Fix names
* Comment fixes
[ROCm/composable_kernel commit: 07092d68f0]
This commit is contained in:
335
include/ck/wrapper/utils/layout_utils.hpp
Normal file
335
include/ck/wrapper/utils/layout_utils.hpp
Normal file
@@ -0,0 +1,335 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/tuple_helper.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/utility/sequence_helper.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
|
||||
// Internal declarations; hidden from the generated doxygen documentation.
/// @cond
// Layout is only forward-declared here so the helpers below can name it.
template <typename Shape, typename Strides>
struct Layout;

// Detection expression: well-formed only when T has a callable IsTuple()
// member. Used below together with is_detected.
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
/// @endcond
|
||||
|
||||
// make_*
|
||||
/**
 * \brief Build a layout from an explicit shape and explicit strides.
 *
 * \tparam Shape Type of the layout shape.
 * \tparam Strides Type of the layout strides.
 * \param shape Shape passed to the Layout constructor.
 * \param strides Strides passed to the Layout constructor.
 * \return Constructed layout.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
                                                                 const Strides& strides)
{
    using LayoutType = Layout<Shape, Strides>;
    return LayoutType(shape, strides);
}
|
||||
|
||||
/**
 * \brief Build a layout from a shape only, with packed strides
 *        (column-major).
 *
 * No strides are passed; the layout is instantiated with an empty
 * Tuple<> as its Strides type.
 *
 * \tparam Shape Type of the layout shape.
 * \param shape Shape passed to the Layout constructor.
 * \return Constructed layout.
 */
template <typename Shape>
__host__ __device__ constexpr Layout<Shape, Tuple<>> make_layout(const Shape& shape)
{
    using PackedLayout = Layout<Shape, Tuple<>>;
    return PackedLayout(shape);
}
|
||||
|
||||
// Layout helpers
|
||||
// get
|
||||
// Identity overload: hands back the dimension itself. It is the end point of
// the hierarchical get when no indexes are left to apply.
/**
 * \private
 */
template <typename T>
__host__ __device__ constexpr T get(const T& dim)
{
    return dim;
}
|
||||
|
||||
/**
 * \brief Get element from tuple (Shape/Strides/Idxs).
 *
 * \tparam idx Index of the element to fetch.
 * \param tuple Tuple to read from.
 * \return Requested element (returned by value).
 */
template <index_t idx, typename... Dims>
__host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
{
    constexpr auto position = Number<idx>{};
    return tuple.At(position);
}
|
||||
|
||||
/**
 * \brief Get sub layout.
 *
 * Takes element idx of the layout's shape (and of its strides, when the
 * layout was created with strides) and builds a new layout from it. The
 * selected element must itself be a tuple; this is enforced at compile time.
 *
 * \tparam idx Index of the sub layout to extract.
 * \param layout Layout to create the sub layout from.
 * \return Requested sub layout.
 */
template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
{
    const auto& sub_shape = get<idx>(layout.GetShape());
    static_assert(is_detected<is_tuple, decltype(sub_shape)>::value,
                  "Shape of sub layout must be tuple");

    if constexpr(!is_same_v<Strides, Tuple<>>)
    {
        // Strides were provided: carry the matching sub-strides along.
        const auto& sub_strides = get<idx>(layout.GetStrides());
        static_assert(is_detected<is_tuple, decltype(sub_strides)>::value,
                      "Strides of sub layout must be tuple");
        return make_layout(sub_shape, sub_strides);
    }
    else
    {
        // Empty strides type: create the sub layout without strides as well.
        return make_layout(sub_shape);
    }
}
|
||||
|
||||
/**
 * \brief Hierarchical get.
 *
 * Applies the first index to elem, then applies the remaining indexes to
 * the result, one nesting level at a time.
 *
 * \tparam Idx First index to look up.
 * \tparam Idxs Indexes for the following nesting levels.
 * \param elem Element to look up in.
 * \return Requested element.
 */
template <index_t Idx, index_t... Idxs, typename T>
__host__ __device__ constexpr auto get(const T& elem)
{
    const auto& outer = get<Idx>(elem);
    return get<Idxs...>(outer);
}
|
||||
|
||||
// size
|
||||
// Identity overload: the size of a single dimension is the dimension itself
// (this is what the size helpers receive after get selects a scalar).
/**
 * \private
 */
template <typename T>
__host__ __device__ constexpr T size(const T& dim)
{
    return dim;
}
|
||||
|
||||
/**
 * \brief Length get (product if tuple).
 *
 * Forwards to Layout::GetLength<idx>().
 *
 * \tparam idx Index of the dimension to look up.
 * \param layout Layout to get the length from.
 * \return Requested length.
 */
template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
{
    const index_t length = layout.template GetLength<idx>();
    return length;
}
|
||||
|
||||
/**
 * \brief Shape size (product of dims).
 *
 * The (possibly nested) shape is first unrolled with UnrollNestedTuple and
 * then all entries are multiplied together with TupleReduce.
 *
 * \param shape Shape to look up.
 * \return Requested size.
 */
template <typename... ShapeDims>
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
{
    const auto flat_dims    = UnrollNestedTuple(shape);
    constexpr auto num_dims = decltype(flat_dims)::Size();
    return TupleReduce<0, num_dims>([](auto x, auto y) { return x * y; }, flat_dims);
}
|
||||
|
||||
/**
 * \brief Layout size (product of dims).
 *
 * Forwards to Layout::GetLengths().
 *
 * \param layout Layout to calculate the shape size of.
 * \return Requested size.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
{
    const index_t total = layout.GetLengths();
    return total;
}
|
||||
|
||||
/**
 * \brief Length get from tuple (product if tuple).
 *
 * Selects element idx of the tuple and returns its size.
 *
 * \tparam idx Index of the element to look up.
 * \param tuple Tuple to look up in.
 * \return Requested length.
 */
template <index_t idx, typename... Ts>
__host__ __device__ constexpr index_t size(const Tuple<Ts...>& tuple)
{
    const auto& entry = tuple.At(Number<idx>{});
    return size(entry);
}
|
||||
|
||||
/**
 * \brief Hierarchical size.
 *
 * Selects the nested element with the hierarchical get and returns its size.
 *
 * \tparam Idx First index to look up (to avoid empty Idxs).
 * \tparam Idxs Next indexes to look up.
 * \param elem Element to look up in.
 * \return Requested size.
 */
template <index_t Idx, index_t... Idxs, typename T>
__host__ __device__ constexpr auto size(const T& elem)
{
    const auto& selected = get<Idx, Idxs...>(elem);
    return size(selected);
}
|
||||
|
||||
// rank
|
||||
/**
 * \brief Get layout rank (num elements in shape).
 *
 * Only the Shape type is inspected; the layout object itself is not read.
 *
 * \param layout Layout to calculate the rank of.
 * \return Requested rank.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto rank([[maybe_unused]] const Layout<Shape, Strides>& layout)
{
    constexpr auto num_dims = Shape::Size();
    return num_dims;
}
|
||||
|
||||
/**
 * \brief Get tuple rank (num elements in tuple).
 *        Scalars are handled by separate overloads that return 1.
 *
 * Only the tuple type is inspected; the tuple object itself is not read.
 *
 * \param tuple Tuple to calculate the rank of.
 * \return Requested rank.
 */
template <typename... Dims>
__host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& tuple)
{
    using TupleType             = Tuple<Dims...>;
    constexpr auto num_elements = TupleType::Size();
    return num_elements;
}
|
||||
|
||||
/**
 * \private
 */
// Rank of a compile-time scalar (Number) is always 1.
template <index_t IDim>
__host__ __device__ constexpr index_t rank(const Number<IDim>&)
{
    return 1;
}
|
||||
|
||||
/**
 * \private
 */
// Rank of a runtime scalar (index_t) is always 1.
__host__ __device__ constexpr index_t rank(const index_t&)
{
    return 1;
}
|
||||
|
||||
/**
 * \brief Hierarchical rank.
 *
 * Selects the nested element with get<Idxs...> and returns its rank.
 *
 * \tparam Idxs Indexes to look up.
 * \param elem Element to look up in.
 * \return Requested rank.
 */
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto rank(const T& elem)
{
    const auto& selected = get<Idxs...>(elem);
    return rank(selected);
}
|
||||
|
||||
// depth
|
||||
/**
 * \brief Get depth of the layout shape (return 0 if scalar).
 *
 * Forwards the layout's shape to TupleDepth.
 *
 * \param layout Layout to calculate the depth of.
 * \return Requested depth.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto depth(const Layout<Shape, Strides>& layout)
{
    return TupleDepth(layout.GetShape());
}
|
||||
|
||||
/**
 * \brief Get depth of the tuple. (return 0 if scalar)
 *
 * Forwards the tuple to TupleDepth.
 *
 * \param tuple Tuple to calculate the depth of.
 * \return Requested depth.
 */
template <typename... Dims>
__host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
{
    return TupleDepth(tuple);
}
|
||||
|
||||
/**
 * \private
 */
// Depth of a compile-time scalar (Number) is always 0.
template <index_t IDim>
__host__ __device__ constexpr index_t depth(const Number<IDim>&)
{
    return 0;
}
|
||||
|
||||
/**
 * \private
 */
// Depth of a runtime scalar (index_t) is always 0.
__host__ __device__ constexpr index_t depth(const index_t&)
{
    return 0;
}
|
||||
|
||||
/**
 * \brief Hierarchical depth.
 *
 * Selects the nested element with get<Idxs...> and returns its depth.
 *
 * \tparam Idxs Indexes to look up.
 * \param elem Element to look up in.
 * \return Requested depth.
 */
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto depth(const T& elem)
{
    const auto& selected = get<Idxs...>(elem);
    return depth(selected);
}
|
||||
|
||||
/**
 * \brief Get Layout strides.
 *
 * Forwards to Layout::GetStrides().
 *
 * \param layout Layout to get the strides from.
 * \return Requested strides.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr const auto& stride(const Layout<Shape, Strides>& layout)
{
    const auto& layout_strides = layout.GetStrides();
    return layout_strides;
}
|
||||
|
||||
/**
 * \brief Get Layout shape.
 *
 * Forwards to Layout::GetShape().
 *
 * \param layout Layout to get the shape from.
 * \return Requested shape.
 */
template <typename Shape, typename Strides>
__host__ __device__ constexpr const auto& shape(const Layout<Shape, Strides>& layout)
{
    const auto& layout_shape = layout.GetShape();
    return layout_shape;
}
|
||||
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
290
include/ck/wrapper/utils/tensor_utils.hpp
Normal file
290
include/ck/wrapper/utils/tensor_utils.hpp
Normal file
@@ -0,0 +1,290 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/tuple_helper.hpp"
|
||||
#include "ck/utility/dynamic_buffer.hpp"
|
||||
#include "ck/utility/amd_address_space.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
|
||||
/**
 * \brief Memory type (alias of AddressSpaceEnum), allowed members:
 *          - Generic,
 *          - Global,
 *          - LDS,
 *          - SGPR,
 *          - VGPR,
 */
using MemoryTypeEnum = AddressSpaceEnum;
|
||||
|
||||
// Internal declarations; hidden from the generated doxygen documentation.
/// @cond
// Forward declarations: only the names are needed by the helpers in this file.
template <typename Shape, typename Strides>
struct Layout;

template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename Strides,
          index_t NumVectors,     // param for Register memory
          index_t ScalarPerVector // param for Register memory
          >
struct Tensor;
|
||||
|
||||
/**
 * Interval [from_, to_] description used to slice one dimension.
 *
 * FromType / ToType are either index_t (value known at run time) or a
 * compile-time constant type such as Number<N>. A negative `to_` counts from
 * the end of the dimension.
 */
template <typename FromType, typename ToType>
struct Slice
{
    __host__ __device__ constexpr Slice() : from_(), to_() {}
    __host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {}

    /**
     * Calculate the length of the slice for a dimension of size `dim`.
     *
     * Returns `to_ - from_` when `to_` is non-negative and
     * `dim - from_ + to_ + 1` when `to_` is negative.
     *
     * \param dim Size of the dimension being sliced.
     * \return Slice length (index_t on the run-time path, otherwise the
     *         result of the compile-time arithmetic).
     */
    template <typename T>
    __host__ __device__ constexpr auto range(const T& dim) const
    {
        if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
                     is_same_v<T, index_t>)
        {
            // At least one operand is only known at run time: validate with assert.
            assert(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_) && "Invalid range");
            if(to_ < 0)
            {
                return dim - from_ + to_ + 1;
            }
            else
            {
                // workaround if one end of the interval is index_t and the second one is Number
                return static_cast<index_t>(to_) - static_cast<index_t>(from_);
            }
        }
        else
        {
            // All operands are compile-time constant types. The checks are made on
            // value-initialised objects of those types: `dim` (a reference
            // parameter) and `to_` / `from_` (reached through `this`) are not
            // usable in constant expressions here.
            static_assert(T{} >= ToType{} && FromType{} >= Number<0>{} &&
                              (ToType{} < 0 || ToType{} > FromType{}),
                          "Invalid range");
            if constexpr(ToType{} < 0)
            {
                return dim - from_ + to_ + Number<1>{};
            }
            else
            {
                return to_ - from_;
            }
        }
    }

    // Marker used by the is_slice detection alias.
    __host__ __device__ static constexpr bool IsSlice() { return true; }

    const FromType from_;
    const ToType to_;
};
|
||||
|
||||
// Detection expression: well-formed only when T has a callable IsSlice() member.
template <typename T>
using is_slice = decltype(std::declval<T&>().IsSlice());

// Detection expression: well-formed only when T has a callable IsTuple() member.
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
/// @endcond
|
||||
|
||||
/**
 * \brief Make tensor function.
 *
 * The register-only parameters of Tensor (NumVectors, ScalarPerVector) are
 * both set to 0.
 *
 * \tparam MemoryType Type of memory.
 * \param pointer Pointer to the memory.
 * \param layout Tensor layout.
 * \return Constructed tensor.
 */
template <MemoryTypeEnum MemoryType, typename ElementType, typename Shape, typename Strides>
constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& layout)
{
    using TensorType =
        Tensor<MemoryType, ElementType, Shape, Strides, 0 /*NumVectors*/, 0 /*ScalarPerVector*/>;
    return TensorType(pointer, layout);
}
|
||||
|
||||
/**
 * \brief Make SGPR or VGPR tensor function.
 *
 * No pointer is taken; the tensor is constructed from the layout alone.
 * Nested shapes are rejected at compile time.
 *
 * \tparam MemoryType Type of memory.
 * \tparam NumVectors Number of vectors.
 * \tparam ScalarPerVector Scalars per vector.
 * \tparam ElementType Memory data type.
 * \param layout Tensor layout.
 * \return Constructed tensor.
 */
template <MemoryTypeEnum MemoryType,
          index_t NumVectors,
          index_t ScalarPerVector,
          typename ElementType,
          typename Shape,
          typename Strides>
constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
{
    static_assert(!IsNestedTuple(Shape{}), "Register tensor with nested layout is not supported");

    using RegisterTensor =
        Tensor<MemoryType, ElementType, Shape, Strides, NumVectors, ScalarPerVector>;
    return RegisterTensor(layout);
}
|
||||
|
||||
/**
 * \brief Get Tensor Layout.
 *
 * Forwards to Tensor::GetLayout().
 *
 * \param tensor Tensor to get the layout of.
 * \return Requested layout.
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename Strides,
          index_t NumVectors,
          index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
layout(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
           tensor)
{
    const auto& tensor_layout = tensor.GetLayout();
    return tensor_layout;
}
|
||||
|
||||
/**
 * \brief Product of tensor shape dims.
 *
 * Forwards to the layout-level size helper on the tensor's layout.
 *
 * \tparam Idxs Indexes to access specific shape dim (optional).
 * \param tensor Tensor to get the shape of.
 * \return Requested size.
 */
template <index_t... Idxs,
          MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename Strides,
          index_t NumVectors,
          index_t ScalarPerVector>
__host__ __device__ constexpr index_t
size(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
         tensor)
{
    const auto& tensor_layout = tensor.GetLayout();
    return size<Idxs...>(tensor_layout);
}
|
||||
|
||||
/**
 * \brief Rank of Shape tuple.
 *
 * Forwards to the layout-level rank helper on the tensor's layout.
 *
 * \tparam Idxs Indexes to access specific shape dim (optional).
 * \param tensor Tensor to get the rank of.
 * \return Requested rank.
 */
template <index_t... Idxs,
          MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename Strides,
          index_t NumVectors,
          index_t ScalarPerVector>
__host__ __device__ constexpr index_t
rank(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
         tensor)
{
    const auto& tensor_layout = tensor.GetLayout();
    return rank<Idxs...>(tensor_layout);
}
|
||||
|
||||
/**
 * \brief Depth of Shape tuple.
 *
 * Forwards to the layout-level depth helper on the tensor's layout.
 *
 * \tparam Idxs Indexes to access specific shape dim (optional).
 * \param tensor Tensor to get the depth of.
 * \return Requested depth.
 */
template <index_t... Idxs,
          MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename Strides,
          index_t NumVectors,
          index_t ScalarPerVector>
__host__ __device__ constexpr index_t
depth(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
          tensor)
{
    const auto& tensor_layout = tensor.GetLayout();
    return depth<Idxs...>(tensor_layout);
}
|
||||
|
||||
/**
 * \brief Get Tensor strides.
 *
 * Forwards to the layout-level stride helper on the tensor's layout.
 *
 * \param tensor Tensor to get the strides from.
 * \return Requested strides.
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename Strides,
          index_t NumVectors,
          index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
           tensor)
{
    const auto& tensor_layout = tensor.GetLayout();
    return stride(tensor_layout);
}
|
||||
|
||||
/**
 * \brief Get Tensor shape.
 *
 * Forwards to the layout-level shape helper on the tensor's layout.
 *
 * \param tensor Tensor to get the shape from.
 * \return Requested shape.
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename Strides,
          index_t NumVectors,
          index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
shape(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
          tensor)
{
    const auto& tensor_layout = tensor.GetLayout();
    return shape(tensor_layout);
}
|
||||
|
||||
/**
 * \brief Get dim slice.
 *
 * \param from Beginning of the interval.
 * \param to End of the interval. (could be also negative to index from the end)
 * \return Requested slice. Could be used to create sliced tensor from other tensor.
 */
template <typename FromType, typename ToType>
constexpr auto slice(const FromType from, const ToType to)
{
    using SliceType = Slice<FromType, ToType>;
    return SliceType(from, to);
}
|
||||
|
||||
/**
 * \brief Get dim slice. (`from` is taken to be 0)
 *
 * The beginning of the interval is index_t 0 when `to` is an index_t and
 * Number<0> otherwise.
 *
 * \param to End of the interval. (could be also negative to index from the end)
 * \return Requested slice. Could be used to create sliced tensor from other tensor.
 */
template <typename ToType>
constexpr auto slice(const ToType to)
{
    if constexpr(!is_same_v<ToType, index_t>)
    {
        // Compile-time end: keep the beginning a compile-time zero as well.
        return Slice<Number<0>, ToType>(Number<0>{}, to);
    }
    else
    {
        // Run-time end: use a run-time zero for the beginning.
        return Slice<index_t, ToType>(0, to);
    }
}
|
||||
|
||||
/**
 * \brief Get whole dim slice (from = 0, to = -1).
 *
 * \return Requested slice. Could be used to create sliced tensor from other tensor.
 */
constexpr auto slice()
{
    using WholeDim = Slice<Number<0>, Number<-1>>;
    return WholeDim(Number<0>{}, Number<-1>{});
}
|
||||
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user