mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Add tensor structure to wrapper (#1098)
* Add tensor structure to wrapper * update changelog * Fix names * Comment fixes
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/wrapper/layout_utils.hpp"
|
||||
#include "ck/wrapper/utils/layout_utils.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
@@ -25,6 +25,26 @@ struct Layout
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
// Generate default idxs tuple (idx with all merged nested shapes)
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr static auto GenerateDefaultIdxsTuple(const Tuple<Ts...>&)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto) {
|
||||
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
|
||||
{
|
||||
// runtime layout
|
||||
return index_t(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
// compiletime layout
|
||||
return I0;
|
||||
}
|
||||
},
|
||||
Number<Tuple<Ts...>::Size()>{});
|
||||
}
|
||||
|
||||
// Generate packed (column-major) strides if not passed
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr static auto
|
||||
@@ -131,7 +151,7 @@ struct Layout
|
||||
|
||||
template <typename... ShapeDims, typename DescriptorToMerge>
|
||||
__host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
|
||||
DescriptorToMerge& desc)
|
||||
const DescriptorToMerge& desc)
|
||||
{
|
||||
// Reverse each element in tuple
|
||||
const auto merge_elems = TupleReverse(UnrollNestedTuple(shape));
|
||||
@@ -144,7 +164,7 @@ struct Layout
|
||||
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
|
||||
}
|
||||
|
||||
// Merge nested shape dims. Merge nested shape dims when idx is also nested.
|
||||
// Merge nested shape dims when corresponding index is also nested.
|
||||
// Input desc shape: 2, 2, 2, 2, 2, 2
|
||||
// Example idx: 1, 1, 1, 1
|
||||
// Example shape: 2, (2, 2), 2, (2, 2)
|
||||
@@ -187,14 +207,38 @@ struct Layout
|
||||
return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
|
||||
}
|
||||
|
||||
template <typename LayoutShape, typename LayoutStrides>
|
||||
__host__ __device__ static auto MakeFlattenDescriptor(const LayoutShape& shape,
|
||||
const LayoutStrides& strides)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
const auto unrolled_strides = UnrollNestedTuple(strides);
|
||||
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
|
||||
"Size of strides and shape are not consistent.");
|
||||
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
|
||||
}
|
||||
|
||||
// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
|
||||
using DeducedStrides =
|
||||
std::conditional_t<is_same_v<Strides, Tuple<>>,
|
||||
remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
|
||||
Strides>;
|
||||
using FlattenDescriptorType =
|
||||
remove_cvref_t<decltype(MakeFlattenDescriptor(Shape{}, DeducedStrides{}))>;
|
||||
using Descriptor1dType =
|
||||
remove_cvref_t<decltype(MakeMerge1d(Shape{}, FlattenDescriptorType{}))>;
|
||||
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
|
||||
|
||||
template <typename... ShapeDims, typename... IdxDims>
|
||||
__host__ __device__ constexpr auto TransformDesc(const Tuple<ShapeDims...>& shape,
|
||||
const Tuple<IdxDims...>& idx) const
|
||||
__host__ __device__ constexpr static auto
|
||||
TransformDesc(const Tuple<ShapeDims...>& shape,
|
||||
const Tuple<IdxDims...>& idx,
|
||||
const FlattenDescriptorType& naive_descriptor)
|
||||
{
|
||||
if constexpr(Tuple<IdxDims...>::Size() == I1)
|
||||
{
|
||||
// 1d idx path
|
||||
return MakeMerge1d(shape, descriptor_);
|
||||
return MakeMerge1d(shape, naive_descriptor);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -207,56 +251,53 @@ struct Layout
|
||||
// Unroll while IdxDims is nested
|
||||
const auto aligned_shape = AlignShapeToIdx(shape, idx);
|
||||
// Transform correct form of shape
|
||||
return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), descriptor_);
|
||||
return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), naive_descriptor);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LayoutShape, typename LayoutStrides>
|
||||
__host__ __device__ static auto MakeNaiveDescriptor(const LayoutShape& shape,
|
||||
const LayoutStrides& strides)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
const auto unrolled_strides = UnrollNestedTuple(strides);
|
||||
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
|
||||
"Size of strides and shape are not consistent.");
|
||||
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
|
||||
}
|
||||
using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
|
||||
Shape{}, DefaultIdxsTupleType{}, FlattenDescriptorType{}))>;
|
||||
|
||||
public:
|
||||
// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
|
||||
using DeducedStrides =
|
||||
std::conditional_t<is_same_v<Strides, Tuple<>>,
|
||||
remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
|
||||
Strides>;
|
||||
using NaiveDescriptorType =
|
||||
remove_cvref_t<decltype(MakeNaiveDescriptor(Shape{}, DeducedStrides{}))>;
|
||||
/**
 * \brief Get the element space size of the layout.
 *
 * \return Value reported by GetElementSpaceSize() of the flatten descriptor.
 */
__host__ __device__ constexpr auto GetElementSpaceSize() const
|
||||
{
|
||||
// Forwarded to the flatten descriptor built from the unrolled shape and strides.
return flatten_descriptor_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
__host__ __device__ Layout() = delete;
|
||||
/**
|
||||
* \brief Layout constructor.
|
||||
*
|
||||
* \param shape Shape for layout.
|
||||
* \param strides Strides for layout (optional if tensor is packed).
|
||||
* \return Layout object.
|
||||
*/
|
||||
__host__ __device__ Layout() = delete;
|
||||
__host__ __device__ Layout(const Shape& shape, const Strides& strides) : descriptor_{}
|
||||
__host__ __device__ constexpr Layout(const Shape& shape, const Strides& strides)
|
||||
: flatten_descriptor_{}, shape_(shape), strides_(strides)
|
||||
{
|
||||
// Construct if runtime mode
|
||||
if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
|
||||
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
|
||||
{
|
||||
shape_ = shape;
|
||||
strides_ = strides;
|
||||
descriptor_ = MakeNaiveDescriptor(shape_, strides_);
|
||||
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
|
||||
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
|
||||
merged_nests_descriptor_ =
|
||||
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ Layout(const Shape& shape) : descriptor_{}
|
||||
/**
|
||||
* \brief Layout constructor (with default packed column-major strides).
|
||||
*
|
||||
* \param shape Shape for layout.
|
||||
*/
|
||||
__host__ __device__ constexpr Layout(const Shape& shape)
|
||||
: flatten_descriptor_{}, shape_(shape), strides_(GenerateColumnMajorPackedStrides(shape_))
|
||||
{
|
||||
if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
|
||||
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
|
||||
{
|
||||
shape_ = shape;
|
||||
strides_ = GenerateColumnMajorPackedStrides(shape_);
|
||||
descriptor_ = MakeNaiveDescriptor(shape_, strides_);
|
||||
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
|
||||
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
|
||||
merged_nests_descriptor_ =
|
||||
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -269,7 +310,9 @@ struct Layout
|
||||
template <typename Idxs>
|
||||
__host__ __device__ constexpr index_t operator()() const
|
||||
{
|
||||
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}));
|
||||
static_assert(FlattenDescriptorType::IsKnownAtCompileTime(),
|
||||
"Compiletime operator used on runtime layout.");
|
||||
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, FlattenDescriptorType{}));
|
||||
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
|
||||
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
|
||||
}
|
||||
@@ -283,9 +326,22 @@ struct Layout
|
||||
template <typename... Ts>
|
||||
__host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const
|
||||
{
|
||||
// Static to construct transformed_desc only once
|
||||
static const auto transformed_desc = TransformDesc(shape_, Idx);
|
||||
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
|
||||
if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == 1)
|
||||
{
|
||||
// if 1d access
|
||||
return descriptor_1d_.CalculateOffset(Idx);
|
||||
}
|
||||
else if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == Shape::Size())
|
||||
{
|
||||
// if Shape::Size() access (merged nested shapes)
|
||||
return merged_nests_descriptor_.CalculateOffset(UnrollNestedTuple(Idx));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Custom index, need to transform descriptor
|
||||
const auto transformed_desc = TransformDesc(shape_, Idx, flatten_descriptor_);
|
||||
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -327,19 +383,51 @@ struct Layout
|
||||
*
|
||||
* \return Shape.
|
||||
*/
|
||||
__host__ __device__ constexpr Shape GetShape() const { return shape_; }
|
||||
__host__ __device__ constexpr const Shape& GetShape() const { return shape_; }
|
||||
|
||||
/**
|
||||
* \brief Strides getter.
|
||||
*
|
||||
* \return Strides.
|
||||
*/
|
||||
__host__ __device__ constexpr DeducedStrides GetStrides() const { return strides_; }
|
||||
__host__ __device__ constexpr const DeducedStrides& GetStrides() const { return strides_; }
|
||||
|
||||
/**
|
||||
* \brief Get default lengths (tuple filled with Shape length elements).
|
||||
*
|
||||
* \return Default lengths.
|
||||
*/
|
||||
__host__ __device__ constexpr auto GetDefaultLengthsTuple() const
{
    // One entry per top-level Shape element, each taken from GetLength<dim>().
    constexpr auto num_dims = Number<Shape::Size()>{};
    return generate_tuple([&](auto dim) { return GetLength<dim>(); }, num_dims);
}
|
||||
|
||||
/**
|
||||
* \brief Get default start idx (tuple filled with 0s of the same size as Shape).
|
||||
*
|
||||
* \return Default start idx.
|
||||
*/
|
||||
__host__ __device__ constexpr auto GetDefaultStartIdxs() const
|
||||
{
|
||||
// Delegates to GenerateDefaultIdxsTuple, which yields Number<0> entries for a
// compile-time layout and index_t(0) entries for a runtime layout.
return GenerateDefaultIdxsTuple(shape_);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get default descriptor (with the same size as Shape)
|
||||
*
|
||||
* \return Default descriptor.
|
||||
*/
|
||||
// Const-qualified so the getter can be called on a const Layout, like the other
// getters; it only returns a copy of the stored descriptor.
__host__ __device__ constexpr MergedNestsDescriptorType GetDefaultDescriptor() const
{
    return merged_nests_descriptor_;
}
|
||||
|
||||
private:
|
||||
NaiveDescriptorType descriptor_;
|
||||
Shape shape_;
|
||||
DeducedStrides strides_;
|
||||
FlattenDescriptorType flatten_descriptor_;
|
||||
Descriptor1dType descriptor_1d_;
|
||||
MergedNestsDescriptorType merged_nests_descriptor_;
|
||||
const Shape shape_;
|
||||
const DeducedStrides strides_;
|
||||
};
|
||||
|
||||
} // namespace wrapper
|
||||
|
||||
Reference in New Issue
Block a user