Add tensor structure to wrapper (#1098)

* Add tensor structure to wrapper

* update changelog

* Fix names

* Comment fixes
This commit is contained in:
Bartłomiej Kocot
2023-12-15 12:45:08 +01:00
committed by GitHub
parent efaf31061a
commit 07092d68f0
9 changed files with 1020 additions and 88 deletions

View File

@@ -3,7 +3,7 @@
#pragma once
#include "ck/wrapper/layout_utils.hpp"
#include "ck/wrapper/utils/layout_utils.hpp"
namespace ck {
namespace wrapper {
@@ -25,6 +25,26 @@ struct Layout
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
// Generate default idxs tuple (idx with all merged nested shapes)
template <typename... Ts>
__host__ __device__ constexpr static auto GenerateDefaultIdxsTuple(const Tuple<Ts...>&)
{
return generate_tuple(
[&](auto) {
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
{
// runtime layout
return index_t(0);
}
else
{
// compiletime layout
return I0;
}
},
Number<Tuple<Ts...>::Size()>{});
}
// Generate packed (column-major) strides if not passed
template <typename... Ts>
__host__ __device__ constexpr static auto
@@ -131,7 +151,7 @@ struct Layout
template <typename... ShapeDims, typename DescriptorToMerge>
__host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
DescriptorToMerge& desc)
const DescriptorToMerge& desc)
{
// Reverse each element in tuple
const auto merge_elems = TupleReverse(UnrollNestedTuple(shape));
@@ -144,7 +164,7 @@ struct Layout
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
}
// Merge nested shape dims. Merge nested shape dims when idx is also nested.
// Merge nested shape dims when corresponding index is also nested.
// Input desc shape: 2, 2, 2, 2, 2, 2
// Example idx: 1, 1, 1, 1
// Example shape: 2, (2, 2), 2, (2, 2)
@@ -187,14 +207,38 @@ struct Layout
return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
}
template <typename LayoutShape, typename LayoutStrides>
__host__ __device__ static auto MakeFlattenDescriptor(const LayoutShape& shape,
const LayoutStrides& strides)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
const auto unrolled_strides = UnrollNestedTuple(strides);
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent.");
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
using DeducedStrides =
std::conditional_t<is_same_v<Strides, Tuple<>>,
remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
Strides>;
using FlattenDescriptorType =
remove_cvref_t<decltype(MakeFlattenDescriptor(Shape{}, DeducedStrides{}))>;
using Descriptor1dType =
remove_cvref_t<decltype(MakeMerge1d(Shape{}, FlattenDescriptorType{}))>;
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr auto TransformDesc(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx) const
__host__ __device__ constexpr static auto
TransformDesc(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx,
const FlattenDescriptorType& naive_descriptor)
{
if constexpr(Tuple<IdxDims...>::Size() == I1)
{
// 1d idx path
return MakeMerge1d(shape, descriptor_);
return MakeMerge1d(shape, naive_descriptor);
}
else
{
@@ -207,56 +251,53 @@ struct Layout
// Unroll while IdxDims is nested
const auto aligned_shape = AlignShapeToIdx(shape, idx);
// Transform correct form of shape
return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), descriptor_);
return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), naive_descriptor);
}
}
template <typename LayoutShape, typename LayoutStrides>
__host__ __device__ static auto MakeNaiveDescriptor(const LayoutShape& shape,
const LayoutStrides& strides)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
const auto unrolled_strides = UnrollNestedTuple(strides);
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent.");
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
Shape{}, DefaultIdxsTupleType{}, FlattenDescriptorType{}))>;
public:
// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
using DeducedStrides =
std::conditional_t<is_same_v<Strides, Tuple<>>,
remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
Strides>;
using NaiveDescriptorType =
remove_cvref_t<decltype(MakeNaiveDescriptor(Shape{}, DeducedStrides{}))>;
__host__ __device__ constexpr auto GetElementSpaceSize() const
{
return flatten_descriptor_.GetElementSpaceSize();
}
__host__ __device__ Layout() = delete;
/**
* \brief Layout constructor.
*
* \param shape Shape for layout.
* \param strides Strides for layout (optional if tensor is packed).
* \return Layout object.
*/
__host__ __device__ Layout() = delete;
__host__ __device__ Layout(const Shape& shape, const Strides& strides) : descriptor_{}
__host__ __device__ constexpr Layout(const Shape& shape, const Strides& strides)
: flatten_descriptor_{}, shape_(shape), strides_(strides)
{
// Construct if runtime mode
if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
{
shape_ = shape;
strides_ = strides;
descriptor_ = MakeNaiveDescriptor(shape_, strides_);
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
merged_nests_descriptor_ =
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
}
}
__host__ __device__ Layout(const Shape& shape) : descriptor_{}
/**
* \brief Layout constructor (with default packed column-major strides).
*
* \param shape Shape for layout.
*/
__host__ __device__ constexpr Layout(const Shape& shape)
: flatten_descriptor_{}, shape_(shape), strides_(GenerateColumnMajorPackedStrides(shape_))
{
if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
{
shape_ = shape;
strides_ = GenerateColumnMajorPackedStrides(shape_);
descriptor_ = MakeNaiveDescriptor(shape_, strides_);
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
merged_nests_descriptor_ =
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
}
}
@@ -269,7 +310,9 @@ struct Layout
template <typename Idxs>
__host__ __device__ constexpr index_t operator()() const
{
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}));
static_assert(FlattenDescriptorType::IsKnownAtCompileTime(),
"Compiletime operator used on runtime layout.");
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, FlattenDescriptorType{}));
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
}
@@ -283,9 +326,22 @@ struct Layout
template <typename... Ts>
__host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const
{
// Static to construct transformed_desc only once
static const auto transformed_desc = TransformDesc(shape_, Idx);
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == 1)
{
// if 1d access
return descriptor_1d_.CalculateOffset(Idx);
}
else if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == Shape::Size())
{
// if Shape::Size() access (merged nested shapes)
return merged_nests_descriptor_.CalculateOffset(UnrollNestedTuple(Idx));
}
else
{
// Custom index, need to transform descriptor
const auto transformed_desc = TransformDesc(shape_, Idx, flatten_descriptor_);
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
}
}
/**
@@ -327,19 +383,51 @@ struct Layout
*
* \return Shape.
*/
__host__ __device__ constexpr Shape GetShape() const { return shape_; }
__host__ __device__ constexpr const Shape& GetShape() const { return shape_; }
/**
* \brief Strides getter.
*
* \return Strides.
*/
__host__ __device__ constexpr DeducedStrides GetStrides() const { return strides_; }
__host__ __device__ constexpr const DeducedStrides& GetStrides() const { return strides_; }
/**
* \brief Get default lengths (tuple filled with Shape length elements).
*
* \return Default lengths.
*/
__host__ __device__ constexpr auto GetDefaultLengthsTuple() const
{
return generate_tuple([&](auto i) { return GetLength<i>(); }, Number<Shape::Size()>{});
}
/**
* \brief Get default start idx (tuple filled with 0s of the same size as Shape).
*
* \return Default start idx.
*/
__host__ __device__ constexpr auto GetDefaultStartIdxs() const
{
return GenerateDefaultIdxsTuple(shape_);
}
/**
* \brief Get default descriptor (with the same size as Shape)
*
* \return Default descriptor.
*/
__host__ __device__ constexpr MergedNestsDescriptorType GetDefaultDescriptor()
{
return merged_nests_descriptor_;
}
private:
NaiveDescriptorType descriptor_;
Shape shape_;
DeducedStrides strides_;
FlattenDescriptorType flatten_descriptor_;
Descriptor1dType descriptor_1d_;
MergedNestsDescriptorType merged_nests_descriptor_;
const Shape shape_;
const DeducedStrides strides_;
};
} // namespace wrapper

View File

@@ -0,0 +1,314 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "utils/tensor_utils.hpp"
#include "utils/layout_utils.hpp"
namespace ck {
namespace wrapper {
/**
* \brief Tensor wrapper that performs static and dynamic buffer logic.
*
* \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
* \tparam ElementType Element data type.
* \tparam Shape Tensor shape (layout component).
* \tparam Strides Tensor strides (layout component).
* \tparam NumVectors Number of vectors (only for VGPR, SGPR).
* \tparam ScalarPerVector Scalars per vector (only for VGPR, SGPR).
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors, // param for Register memory
index_t ScalarPerVector // param for Register memory
>
struct Tensor
{
private:
// Check if Tuple contains Slice object
template <typename T>
constexpr static bool IsSlicing(T&&)
{
return is_detected<is_slice, T>::value;
}
template <typename... Ts>
constexpr static bool IsSlicing(Tuple<Ts...>&&)
{
return (IsSlicing(Ts{}) || ...);
}
// Calculate first index of new tensor after slice
// It is needed to calculate offset for new tensor
template <typename... Ts>
constexpr auto GetStartIdxForSlicedTensor(const Tuple<Ts...>& idx) const
{
const auto start_idx_for_sliced_tensor = generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// if tuple then recurrence
return GetStartIdxForSlicedTensor(idx.At(num_i));
}
else if constexpr(is_detected<is_slice,
tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// if slice, return the beginning of the interval
return idx.At(num_i).from_;
}
else
{
// if one dim selected
return idx.At(num_i);
}
},
Number<Tuple<Ts...>::Size()>{});
return start_idx_for_sliced_tensor;
}
// Calculate new tensor shape after slice
template <typename... Ts, typename ShapeTmpType>
constexpr auto GetShapeFromSlicedTensor(const Tuple<Ts...>& idx,
const ShapeTmpType& shape) const
{
// Pack each value in tuple to remove empty tuples after generation
auto new_shape = generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
if constexpr(!IsSlicing(tuple_element_t<i.value, Tuple<Ts...>>{}))
{
// if tuple does not have any slice then we can remove dimension
return Tuple<>{};
}
else
{
// if tuple then recurrence
return make_tuple(GetShapeFromSlicedTensor(idx.At(num_i), shape.At(num_i)));
}
}
else if constexpr(is_detected<is_slice,
tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// calculate new dimension
const auto& dim = size(shape.At(num_i));
const auto val = idx.At(num_i).range(dim);
return make_tuple(val);
}
else
{
// remove dimension for just value
return Tuple<>{};
}
},
Number<Tuple<Ts...>::Size()>{});
// Remove empty tuples (deleted elements) and return
return UnrollNestedTuple<0, 1>(new_shape);
}
template <typename... Ts, typename StridesTmpType>
constexpr auto GetStridesFromSlicedTensor(const Tuple<Ts...>& idx,
const StridesTmpType& strides) const
{
// Pack each value in tuple to remove empty tuples after generation
auto new_strides = generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
if constexpr(!IsSlicing(tuple_element_t<i.value, Tuple<Ts...>>{}))
{
// if tuple does not have any slice then we can remove dimension
return Tuple<>{};
}
else
{
// if tuple then recurrence
return make_tuple(
GetStridesFromSlicedTensor(idx.At(num_i), strides.At(num_i)));
}
}
else if constexpr(is_detected<is_slice,
tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// Stride will be the same
return make_tuple(strides.At(num_i));
}
else
{
// remove dimension for just value
return Tuple<>{};
}
},
Number<Tuple<Ts...>::Size()>{});
// Remove empty tuples (deleted elements) and return
return UnrollNestedTuple<0, 1>(new_strides);
}
public:
using ElementSpaceSize = decltype(Layout<Shape, Strides>{
Shape{}, Strides{}}.GetElementSpaceSize()); // SpaceSize type for buffer
using TensorElementType = ElementType; // DataType
static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr ||
BufferAddressSpace == MemoryTypeEnum ::Vgpr);
__host__ __device__ Tensor() = delete;
__host__ __device__ Tensor(ElementType* pointer, const Layout<Shape, Strides>& layout)
: layout_(layout),
buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize()))
{
}
__host__ __device__ Tensor(const Layout<Shape, Strides>& layout) : layout_(layout)
{
static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
}
__host__ __device__ constexpr const Layout<Shape, Strides>& GetLayout() const
{
return layout_;
}
// Getter for new sliced tensor
template <typename... Ts, enable_if_t<IsSlicing(Tuple<Ts...>{}), bool> = false>
__host__ __device__ auto operator[](const Tuple<Ts...>& idx) const
{
static_assert(IsDynamicBuffer, "Register slice is not supported");
// Calculate offset based on first idx for new tensor
const index_t offset = layout_(GetStartIdxForSlicedTensor(idx));
auto new_shape = GetShapeFromSlicedTensor(idx, layout_.GetShape());
if constexpr(is_same_v<Strides, Tuple<>>)
{
auto new_layout = make_layout(new_shape);
return make_tensor<BufferAddressSpace>(buffer_.p_data_ + offset, new_layout);
}
else
{
auto new_strides = GetStridesFromSlicedTensor(idx, layout_.GetStrides());
auto new_layout = make_layout(new_shape, new_strides);
return make_tensor<BufferAddressSpace>(buffer_.p_data_ + offset, new_layout);
}
}
template <typename... Ts, enable_if_t<IsSlicing(Tuple<Ts...>{}), bool> = false>
__host__ __device__ auto operator()(const Tuple<Ts...>& idx) const
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<IsSlicing(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ auto operator()(Idxs... idxs) const
{
return this->operator[](make_tuple(idxs...));
}
// Getter for the const value
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
__host__ __device__ const ElementType& operator[](const Tuple<Ts...>& idx) const
{
if constexpr(IsDynamicBuffer)
{
const index_t offset = layout_(idx);
return buffer_[offset];
}
else
{
if constexpr(is_same_v<Strides, Tuple<>>)
{
constexpr index_t offset =
Layout<Shape, Strides>{Shape{}}.template operator()<Tuple<Ts...>>();
return buffer_[Number<offset>{}];
}
else
{
constexpr index_t offset =
Layout<Shape, Strides>{Shape{}, Strides{}}.template operator()<Tuple<Ts...>>();
return buffer_[Number<offset>{}];
}
}
}
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
__host__ __device__ const ElementType& operator()(const Tuple<Ts...>& idx) const
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<!IsSlicing(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ const ElementType& operator()(Idxs... idxs) const
{
return this->operator[](make_tuple(idxs...));
}
// Getter for the value reference
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
__host__ __device__ ElementType& operator[](const Tuple<Ts...>& idx)
{
if constexpr(IsDynamicBuffer)
{
const index_t offset = layout_(idx);
return buffer_(offset);
}
else
{
if constexpr(is_same_v<Strides, Tuple<>>)
{
constexpr index_t offset =
Layout<Shape, Strides>{Shape{}}.template operator()<Tuple<Ts...>>();
return buffer_(Number<offset>{});
}
else
{
constexpr index_t offset =
Layout<Shape, Strides>{Shape{}, Strides{}}.template operator()<Tuple<Ts...>>();
return buffer_(Number<offset>{});
}
}
}
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
__host__ __device__ ElementType& operator()(const Tuple<Ts...>& idx)
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<!IsSlicing(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ ElementType& operator()(Idxs... idxs)
{
return this->operator[](make_tuple(idxs...));
}
__host__ __device__ constexpr auto GetDefaultDescriptor()
{
return layout_.GetDefaultDescriptor();
}
private:
using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
ElementType,
ElementSpaceSize,
true /*InvalidElementUseNumericalZeroValue*/>;
using StaticBufferType =
StaticBufferTupleOfVector<BufferAddressSpace,
ElementType,
NumVectors,
ScalarPerVector,
true /*InvalidElementUseNumericalZeroValue*/>;
// If register use static buffer, else use dynamic buffer
using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;
const Layout<Shape, Strides> layout_;
Buffer buffer_;
};
} // namespace wrapper
} // namespace ck

View File

@@ -22,7 +22,7 @@ namespace wrapper {
// Disable from doxygen docs generation
/// @cond
// forward declaration
template <typename Shape, typename Strides = Tuple<>>
template <typename Shape, typename Strides>
struct Layout;
template <typename T>
@@ -52,13 +52,23 @@ __host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& sh
* \return Constructed layout.
*/
template <typename Shape>
__host__ __device__ constexpr Layout<Shape> make_layout(const Shape& shape)
__host__ __device__ constexpr Layout<Shape, Tuple<>> make_layout(const Shape& shape)
{
return Layout<Shape>(shape);
return Layout<Shape, Tuple<>>(shape);
}
// Layout helpers
// get
// Get dim (could be returned from get with empty Idxs)
/**
* \private
*/
template <typename T>
__host__ __device__ T constexpr get(const T& dim)
{
return dim;
}
/**
* \brief Get element from tuple (Shape/Strides/Idxs).
*
@@ -82,7 +92,8 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
template <index_t idx, typename Shape, typename Strides>
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
{
const auto new_shape = get<idx>(layout.GetShape());
const auto& shape = layout.GetShape();
const auto& new_shape = get<idx>(shape);
static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
"Shape of sub layout must be tuple");
if constexpr(is_same_v<Strides, Tuple<>>)
@@ -92,7 +103,8 @@ __host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
}
else
{
const auto new_strides = get<idx>(layout.GetStrides());
const auto& strides = layout.GetStrides();
const auto& new_strides = get<idx>(strides);
static_assert(is_detected<is_tuple, decltype(new_strides)>::value,
"Strides of sub layout must be tuple");
return make_layout(new_shape, new_strides);
@@ -113,11 +125,21 @@ __host__ __device__ constexpr auto get(const T& elem)
}
// size
// Get dim size (could be returned from get function)
/**
* \private
*/
template <typename T>
__host__ __device__ T constexpr size(const T& dim)
{
return dim;
}
/**
* \brief Length get (product if tuple).
*
* \tparam idx Index to lookup.
* \param layout Layout to get Shape.
* \param layout Layout to get Shape of.
* \return Requsted length.
*/
template <index_t idx, typename Shape, typename Strides>
@@ -140,16 +162,6 @@ __host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
unrolled_shape);
}
// Get dim size (could be returned from get function)
/**
* \private
*/
template <typename T>
__host__ __device__ T constexpr size(const T& dim)
{
return dim;
}
/**
* \brief Layout size (product of dims).
*
@@ -178,14 +190,15 @@ __host__ __device__ constexpr index_t size(const Tuple<Ts...>& tuple)
/**
* \brief Hierarchical size.
*
* \tparam Idxs Indexes to lookup.
* \tparam Idx First index to lookup (to avoid empty Idxs).
* \tparam Idxs Next indexes to lookup.
* \param elem Element to lookup.
* \return Requsted element.
*/
template <index_t... Idxs, typename T>
template <index_t Idx, index_t... Idxs, typename T>
__host__ __device__ constexpr auto size(const T& elem)
{
return size(get<Idxs...>(elem));
return size(get<Idx, Idxs...>(elem));
}
// rank
@@ -251,7 +264,8 @@ __host__ __device__ constexpr auto rank(const T& elem)
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto depth(const Layout<Shape, Strides>& layout)
{
return TupleDepth(layout.GetShape());
const auto& shape = layout.GetShape();
return TupleDepth(shape);
}
/**
@@ -296,11 +310,11 @@ __host__ __device__ constexpr auto depth(const T& elem)
/**
* \brief Get Layout strides.
*
* \param layout Layout to get strides.
* \param layout Layout to get strides from.
* \return Requsted strides.
*/
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto stride(const Layout<Shape, Strides>& layout)
__host__ __device__ constexpr const auto& stride(const Layout<Shape, Strides>& layout)
{
return layout.GetStrides();
}
@@ -308,11 +322,11 @@ __host__ __device__ constexpr auto stride(const Layout<Shape, Strides>& layout)
/**
* \brief Get Layout shape.
*
* \param layout Layout to get shape.
* \param layout Layout to get shape from.
* \return Requsted shape.
*/
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto shape(const Layout<Shape, Strides>& layout)
__host__ __device__ constexpr const auto& shape(const Layout<Shape, Strides>& layout)
{
return layout.GetShape();
}

View File

@@ -0,0 +1,290 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/dynamic_buffer.hpp"
#include "ck/utility/amd_address_space.hpp"
namespace ck {
namespace wrapper {
/**
* \brief Memory type, allowed members:
* - Generic,
* - Global,
* - LDS,
* - SGPR,
* - VGPR,
*/
using MemoryTypeEnum = AddressSpaceEnum;
// Disable from doxygen docs generation
/// @cond
// forward declarations
template <typename Shape, typename Strides>
struct Layout;
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors, // params for Register memory
index_t ScalarPerVector // param for Register memory
>
struct Tensor;
template <typename FromType, typename ToType>
struct Slice
{
__host__ __device__ constexpr Slice() : from_(), to_() {}
__host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {}
template <typename T>
__host__ __device__ constexpr auto range(const T& dim) const
{
if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
is_same_v<T, index_t>)
{
assert(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_) && "Invalid range");
if(to_ < 0)
{
return dim - from_ + to_ + 1;
}
else
{
// workaround if one end of the interval is index_t and the second one is Number
return static_cast<index_t>(to_) - static_cast<index_t>(from_);
}
}
else
{
static_assert(dim >= to_ && from_ >= Number<0>{} && (to_ < 0 || to_ > from_),
"Invalid range");
if constexpr(to_ < 0)
{
return dim - from_ + to_ + Number<1>{};
}
else
{
return to_ - from_;
}
}
}
__host__ __device__ static constexpr bool IsSlice() { return true; }
const FromType from_;
const ToType to_;
};
template <typename T>
using is_slice = decltype(std::declval<T&>().IsSlice());
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
/// @endcond
/**
* \brief Make tensor function.
*
* \tparam MemoryType Type of memory.
* \param pointer Pointer to the memory.
* \param layout Tensor layout.
* \return Constructed tensor.
*/
template <MemoryTypeEnum MemoryType, typename ElementType, typename Shape, typename Strides>
constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& layout)
{
return Tensor<MemoryType, ElementType, Shape, Strides, 0 /*NumVectors*/, 0 /*ScalarPerVector*/>(
pointer, layout);
}
/**
* \brief Make SGPR or VGPR tensor function.
*
* \tparam MemoryType Type of memory.
* \tparam NumVectors Number of vectors.
* \tparam ScalarPerVector Scalars per vector.
* \tparam ElementType Memory data type.
* \param layout Tensor layout.
* \return Constructed tensor.
*/
template <MemoryTypeEnum MemoryType,
index_t NumVectors,
index_t ScalarPerVector,
typename ElementType,
typename Shape,
typename Strides>
constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
{
static_assert(!IsNestedTuple(Shape{}), "Register tensor with nested layout is not supported");
return Tensor<MemoryType, ElementType, Shape, Strides, NumVectors, ScalarPerVector>(layout);
}
/**
* \brief Get Tensor Layout.
*
* \param tensor Tensor to get layout of.
* \return Requsted layout.
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
layout(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
{
return tensor.GetLayout();
}
/**
* \brief Product of tensor shape dims.
*
* \tparam Idxs Indexes to access specific shape dim (optional).
* \param tensor Tensor to get Shape of.
* \return Requsted size.
*/
template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr index_t
size(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
{
return size<Idxs...>(tensor.GetLayout());
}
/**
* \brief Rank of Shape tuple.
*
* \tparam Idxs Indexes to access specific shape dim (optional).
* \param tensor Tensor to get rank of.
* \return Requsted rank.
*/
template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr index_t
rank(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
{
return rank<Idxs...>(tensor.GetLayout());
}
/**
* \brief Depth of Shape tuple.
*
* \tparam Idxs Indexes to access specific shape dim (optional).
* \param tensor Tensor to get depth of.
* \return Requsted depth.
*/
template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr index_t
depth(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
{
return depth<Idxs...>(tensor.GetLayout());
}
/**
* \brief Get Tensor strides.
*
* \param tensor Tensor to get strides from.
* \return Requsted strides.
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
{
return stride(tensor.GetLayout());
}
/**
* \brief Get Tensor shape.
*
* \param tensor Tensor to get shape from.
* \return Requsted shape.
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
shape(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
{
return shape(tensor.GetLayout());
}
/**
* \brief Get dim slice.
*
* \param from Beginning of the interval.
* \param to End of the interval. (could be also negative to index from the end)
* \return Requested slice. Could be used to create sliced tensor from other tensor.
*/
template <typename FromType, typename ToType>
constexpr auto slice(const FromType from, const ToType to)
{
return Slice<FromType, ToType>(from, to);
}
/**
* \brief Get dim slice. (Assumed that from is equal to 1)
*
* \param to End of the interval. (could be also negative to index from the end)
* \return Requested slice. Could be used to create sliced tensor from other tensor.
*/
template <typename ToType>
constexpr auto slice(const ToType to)
{
if constexpr(is_same_v<ToType, index_t>)
{
return Slice<index_t, ToType>(0, to);
}
else
{
return Slice<Number<0>, ToType>(Number<0>{}, to);
}
}
/**
* \brief Get whole dim slice (from = 0, to = -1).
*
* \return Requested slice. Could be used to create sliced tensor from other tensor.
*/
constexpr auto slice() { return Slice<Number<0>, Number<-1>>(Number<0>{}, Number<-1>{}); }
} // namespace wrapper
} // namespace ck