Add tensor partition and generic copy for ck wrapper (#1108)

* Add tensor partition and generic copy for ck wrapper

* Update changelog

* Stylistic fixes

* Change shape/strides logic to descriptor transforms

* Fixes

* Fix client example

* Fix comments
This commit is contained in:
Bartłomiej Kocot
2024-01-03 01:10:57 +01:00
committed by GitHub
parent b268f273de
commit 4234b3a691
14 changed files with 940 additions and 306 deletions

View File

@@ -14,11 +14,9 @@ namespace wrapper {
* \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
* (dynamic layout). It is possible to pass nested shapes
* (e.g. ((4, 2), 2)), nested dimensions are merged.
* \tparam Strides Tuple of Number<> (for compile-time layout) or index_t
* (dynamic layout). Stride tuple should be nested if shape tuple is
* nested.
* \tparam UnnestedDescriptorType Tensor descriptor for unnested shape dims.
*/
template <typename Shape, typename Strides>
template <typename Shape, typename UnnestedDescriptorType>
struct Layout
{
private:
@@ -31,7 +29,7 @@ struct Layout
{
return generate_tuple(
[&](auto) {
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
{
// runtime layout
return index_t(0);
@@ -45,27 +43,6 @@ struct Layout
Number<Tuple<Ts...>::Size()>{});
}
// Generate packed (column-major) strides if not passed
template <typename... Ts>
__host__ __device__ constexpr static auto
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
return generate_tuple(
[&](auto i) {
if constexpr(i.value == 0)
{
return I1;
}
else
{
return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
unrolled_shape);
}
},
Number<decltype(unrolled_shape)::Size()>{});
}
// Generate LowerDims in Compile-time for MergeTrasform using passed Type
// If element of Tuple<Ts...> is also tuple, then merge (generate sequence for merge)
// If tuple is element, then pass through (sequence with one element)
@@ -207,33 +184,15 @@ struct Layout
return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
}
template <typename LayoutShape, typename LayoutStrides>
__host__ __device__ static auto MakeFlattenDescriptor(const LayoutShape& shape,
const LayoutStrides& strides)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
const auto unrolled_strides = UnrollNestedTuple(strides);
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent.");
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
using DeducedStrides =
std::conditional_t<is_same_v<Strides, Tuple<>>,
remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
Strides>;
using FlattenDescriptorType =
remove_cvref_t<decltype(MakeFlattenDescriptor(Shape{}, DeducedStrides{}))>;
using Descriptor1dType =
remove_cvref_t<decltype(MakeMerge1d(Shape{}, FlattenDescriptorType{}))>;
remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnnestedDescriptorType{}))>;
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto
TransformDesc(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx,
const FlattenDescriptorType& naive_descriptor)
const UnnestedDescriptorType& naive_descriptor)
{
if constexpr(Tuple<IdxDims...>::Size() == I1)
{
@@ -256,48 +215,33 @@ struct Layout
}
using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
Shape{}, DefaultIdxsTupleType{}, FlattenDescriptorType{}))>;
Shape{}, DefaultIdxsTupleType{}, UnnestedDescriptorType{}))>;
public:
__host__ __device__ constexpr auto GetElementSpaceSize() const
{
return flatten_descriptor_.GetElementSpaceSize();
return unnested_descriptor_.GetElementSpaceSize();
}
__host__ __device__ Layout() = delete;
/**
* \brief Layout constructor.
*
* \param shape Shape for layout.
* \param strides Strides for layout (optional if tensor is packed).
* \param unnested_descriptor Descriptor
*/
__host__ __device__ constexpr Layout(const Shape& shape, const Strides& strides)
: flatten_descriptor_{}, shape_(shape), strides_(strides)
__host__ __device__ constexpr Layout(const Shape& shape,
const UnnestedDescriptorType& unnested_descriptor)
: shape_(shape)
{
// Construct if runtime mode
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
{
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
unnested_descriptor_ = unnested_descriptor;
descriptor_1d_ = MakeMerge1d(shape_, unnested_descriptor_);
merged_nests_descriptor_ =
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
}
}
/**
* \brief Layout constructor (with default packed column-major strides).
*
* \param shape Shape for layout.
*/
__host__ __device__ constexpr Layout(const Shape& shape)
: flatten_descriptor_{}, shape_(shape), strides_(GenerateColumnMajorPackedStrides(shape_))
{
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
{
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
merged_nests_descriptor_ =
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
TransformDesc(shape_, DefaultIdxsTupleType{}, unnested_descriptor_);
}
}
@@ -310,9 +254,9 @@ struct Layout
template <typename Idxs>
__host__ __device__ constexpr index_t operator()() const
{
static_assert(FlattenDescriptorType::IsKnownAtCompileTime(),
static_assert(UnnestedDescriptorType::IsKnownAtCompileTime(),
"Compiletime operator used on runtime layout.");
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, FlattenDescriptorType{}));
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnnestedDescriptorType{}));
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
}
@@ -339,7 +283,7 @@ struct Layout
else
{
// Custom index, need to transform descriptor
const auto transformed_desc = TransformDesc(shape_, Idx, flatten_descriptor_);
const auto transformed_desc = TransformDesc(shape_, Idx, unnested_descriptor_);
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
}
}
@@ -351,7 +295,7 @@ struct Layout
* \return Calculated size.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t GetLength() const
__host__ __device__ constexpr auto GetLength() const
{
const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
@@ -371,7 +315,7 @@ struct Layout
*
* \return Calculated size.
*/
__host__ __device__ constexpr index_t GetLengths() const
__host__ __device__ constexpr auto GetLengths() const
{
const auto unrolled_shape = UnrollNestedTuple(shape_);
return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
@@ -385,13 +329,6 @@ struct Layout
*/
__host__ __device__ constexpr const Shape& GetShape() const { return shape_; }
/**
* \brief Strides getter.
*
* \return Strides.
*/
__host__ __device__ constexpr const DeducedStrides& GetStrides() const { return strides_; }
/**
* \brief Get default lengths (tuple filled with Shape length elements).
*
@@ -417,17 +354,26 @@ struct Layout
*
* \return Default descriptor.
*/
__host__ __device__ constexpr MergedNestsDescriptorType GetDefaultDescriptor()
__host__ __device__ constexpr const MergedNestsDescriptorType& GetDefaultDescriptor() const
{
return merged_nests_descriptor_;
}
/**
* \brief Get unnested descriptor (with unrolled dims)
*
* \return Flatten descriptor.
*/
__host__ __device__ constexpr const UnnestedDescriptorType& GetUnnestedDescriptor() const
{
return unnested_descriptor_;
}
private:
FlattenDescriptorType flatten_descriptor_;
UnnestedDescriptorType unnested_descriptor_;
Descriptor1dType descriptor_1d_;
MergedNestsDescriptorType merged_nests_descriptor_;
const Shape shape_;
const DeducedStrides strides_;
};
} // namespace wrapper