mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Add tensor partition and generic copy for ck wrapper (#1108)
* Add tensor partition and generic copy for ck wrapper * Update changelog * Stylistic fixes * Change shape/strides logic to descriptor transforms * Fixes * Fix client example * Fix comments
This commit is contained in:
@@ -14,11 +14,9 @@ namespace wrapper {
|
||||
* \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
|
||||
* (dynamic layout). It is possible to pass nested shapes
|
||||
* (e.g. ((4, 2), 2)), nested dimensions are merged.
|
||||
* \tparam Strides Tuple of Number<> (for compile-time layout) or index_t
|
||||
* (dynamic layout). Stride tuple should be nested if shape tuple is
|
||||
* nested.
|
||||
* \tparam UnnestedDescriptorType Tensor descriptor for unnested shape dims.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
template <typename Shape, typename UnnestedDescriptorType>
|
||||
struct Layout
|
||||
{
|
||||
private:
|
||||
@@ -31,7 +29,7 @@ struct Layout
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto) {
|
||||
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
|
||||
if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
|
||||
{
|
||||
// runtime layout
|
||||
return index_t(0);
|
||||
@@ -45,27 +43,6 @@ struct Layout
|
||||
Number<Tuple<Ts...>::Size()>{});
|
||||
}
|
||||
|
||||
// Generate packed (column-major) strides if not passed
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr static auto
|
||||
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i.value == 0)
|
||||
{
|
||||
return I1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
|
||||
unrolled_shape);
|
||||
}
|
||||
},
|
||||
Number<decltype(unrolled_shape)::Size()>{});
|
||||
}
|
||||
|
||||
// Generate LowerDims in Compile-time for MergeTrasform using passed Type
|
||||
// If element of Tuple<Ts...> is also tuple, then merge (generate sequence for merge)
|
||||
// If tuple is element, then pass through (sequence with one element)
|
||||
@@ -207,33 +184,15 @@ struct Layout
|
||||
return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
|
||||
}
|
||||
|
||||
template <typename LayoutShape, typename LayoutStrides>
|
||||
__host__ __device__ static auto MakeFlattenDescriptor(const LayoutShape& shape,
|
||||
const LayoutStrides& strides)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
const auto unrolled_strides = UnrollNestedTuple(strides);
|
||||
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
|
||||
"Size of strides and shape are not consistent.");
|
||||
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
|
||||
}
|
||||
|
||||
// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
|
||||
using DeducedStrides =
|
||||
std::conditional_t<is_same_v<Strides, Tuple<>>,
|
||||
remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
|
||||
Strides>;
|
||||
using FlattenDescriptorType =
|
||||
remove_cvref_t<decltype(MakeFlattenDescriptor(Shape{}, DeducedStrides{}))>;
|
||||
using Descriptor1dType =
|
||||
remove_cvref_t<decltype(MakeMerge1d(Shape{}, FlattenDescriptorType{}))>;
|
||||
remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnnestedDescriptorType{}))>;
|
||||
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
|
||||
|
||||
template <typename... ShapeDims, typename... IdxDims>
|
||||
__host__ __device__ constexpr static auto
|
||||
TransformDesc(const Tuple<ShapeDims...>& shape,
|
||||
const Tuple<IdxDims...>& idx,
|
||||
const FlattenDescriptorType& naive_descriptor)
|
||||
const UnnestedDescriptorType& naive_descriptor)
|
||||
{
|
||||
if constexpr(Tuple<IdxDims...>::Size() == I1)
|
||||
{
|
||||
@@ -256,48 +215,33 @@ struct Layout
|
||||
}
|
||||
|
||||
using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
|
||||
Shape{}, DefaultIdxsTupleType{}, FlattenDescriptorType{}))>;
|
||||
Shape{}, DefaultIdxsTupleType{}, UnnestedDescriptorType{}))>;
|
||||
|
||||
public:
|
||||
__host__ __device__ constexpr auto GetElementSpaceSize() const
|
||||
{
|
||||
return flatten_descriptor_.GetElementSpaceSize();
|
||||
return unnested_descriptor_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
__host__ __device__ Layout() = delete;
|
||||
|
||||
/**
|
||||
* \brief Layout constructor.
|
||||
*
|
||||
* \param shape Shape for layout.
|
||||
* \param strides Strides for layout (optional if tensor is packed).
|
||||
* \param unnested_descriptor Descriptor
|
||||
*/
|
||||
__host__ __device__ constexpr Layout(const Shape& shape, const Strides& strides)
|
||||
: flatten_descriptor_{}, shape_(shape), strides_(strides)
|
||||
__host__ __device__ constexpr Layout(const Shape& shape,
|
||||
const UnnestedDescriptorType& unnested_descriptor)
|
||||
: shape_(shape)
|
||||
{
|
||||
// Construct if runtime mode
|
||||
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
|
||||
if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
|
||||
{
|
||||
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
|
||||
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
|
||||
unnested_descriptor_ = unnested_descriptor;
|
||||
descriptor_1d_ = MakeMerge1d(shape_, unnested_descriptor_);
|
||||
merged_nests_descriptor_ =
|
||||
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Layout constructor (with default packed column-major strides).
|
||||
*
|
||||
* \param shape Shape for layout.
|
||||
*/
|
||||
__host__ __device__ constexpr Layout(const Shape& shape)
|
||||
: flatten_descriptor_{}, shape_(shape), strides_(GenerateColumnMajorPackedStrides(shape_))
|
||||
{
|
||||
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
|
||||
{
|
||||
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
|
||||
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
|
||||
merged_nests_descriptor_ =
|
||||
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
|
||||
TransformDesc(shape_, DefaultIdxsTupleType{}, unnested_descriptor_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,9 +254,9 @@ struct Layout
|
||||
template <typename Idxs>
|
||||
__host__ __device__ constexpr index_t operator()() const
|
||||
{
|
||||
static_assert(FlattenDescriptorType::IsKnownAtCompileTime(),
|
||||
static_assert(UnnestedDescriptorType::IsKnownAtCompileTime(),
|
||||
"Compiletime operator used on runtime layout.");
|
||||
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, FlattenDescriptorType{}));
|
||||
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnnestedDescriptorType{}));
|
||||
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
|
||||
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
|
||||
}
|
||||
@@ -339,7 +283,7 @@ struct Layout
|
||||
else
|
||||
{
|
||||
// Custom index, need to transform descriptor
|
||||
const auto transformed_desc = TransformDesc(shape_, Idx, flatten_descriptor_);
|
||||
const auto transformed_desc = TransformDesc(shape_, Idx, unnested_descriptor_);
|
||||
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
|
||||
}
|
||||
}
|
||||
@@ -351,7 +295,7 @@ struct Layout
|
||||
* \return Calculated size.
|
||||
*/
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr index_t GetLength() const
|
||||
__host__ __device__ constexpr auto GetLength() const
|
||||
{
|
||||
const auto elem = shape_.At(Number<IDim>{});
|
||||
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
|
||||
@@ -371,7 +315,7 @@ struct Layout
|
||||
*
|
||||
* \return Calculated size.
|
||||
*/
|
||||
__host__ __device__ constexpr index_t GetLengths() const
|
||||
__host__ __device__ constexpr auto GetLengths() const
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape_);
|
||||
return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
|
||||
@@ -385,13 +329,6 @@ struct Layout
|
||||
*/
|
||||
__host__ __device__ constexpr const Shape& GetShape() const { return shape_; }
|
||||
|
||||
/**
|
||||
* \brief Strides getter.
|
||||
*
|
||||
* \return Strides.
|
||||
*/
|
||||
__host__ __device__ constexpr const DeducedStrides& GetStrides() const { return strides_; }
|
||||
|
||||
/**
|
||||
* \brief Get default lengths (tuple filled with Shape length elements).
|
||||
*
|
||||
@@ -417,17 +354,26 @@ struct Layout
|
||||
*
|
||||
* \return Default descriptor.
|
||||
*/
|
||||
__host__ __device__ constexpr MergedNestsDescriptorType GetDefaultDescriptor()
|
||||
__host__ __device__ constexpr const MergedNestsDescriptorType& GetDefaultDescriptor() const
|
||||
{
|
||||
return merged_nests_descriptor_;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get unnested descriptor (with unrolled dims)
|
||||
*
|
||||
* \return Flatten descriptor.
|
||||
*/
|
||||
__host__ __device__ constexpr const UnnestedDescriptorType& GetUnnestedDescriptor() const
|
||||
{
|
||||
return unnested_descriptor_;
|
||||
}
|
||||
|
||||
private:
|
||||
FlattenDescriptorType flatten_descriptor_;
|
||||
UnnestedDescriptorType unnested_descriptor_;
|
||||
Descriptor1dType descriptor_1d_;
|
||||
MergedNestsDescriptorType merged_nests_descriptor_;
|
||||
const Shape shape_;
|
||||
const DeducedStrides strides_;
|
||||
};
|
||||
|
||||
} // namespace wrapper
|
||||
|
||||
Reference in New Issue
Block a user