mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
Add tensor partition and generic copy for ck wrapper (#1108)
* Add tensor partition and generic copy for ck wrapper * Update changelog * Stylistic fixes * Change shape/strides logic to descriptor transforms * Fixes * Fix client example * Fix comments
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -22,11 +22,57 @@ namespace wrapper {
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond
|
||||
// forward declaration
|
||||
template <typename Shape, typename Strides>
|
||||
template <typename Shape, typename UnnestedDescriptorType>
|
||||
struct Layout;
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
namespace {
|
||||
// Generate packed (column-major) strides if not passed
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr static auto
|
||||
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i.value == 0)
|
||||
{
|
||||
return Number<1>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return TupleReduce<Number<0>{}.value, i.value>([](auto x, auto y) { return x * y; },
|
||||
unrolled_shape);
|
||||
}
|
||||
},
|
||||
Number<decltype(unrolled_shape)::Size()>{});
|
||||
}
|
||||
|
||||
template <typename LayoutShape, typename LayoutStrides>
|
||||
__host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shape,
|
||||
const LayoutStrides& strides)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
if constexpr(is_same_v<LayoutStrides, Tuple<>>)
|
||||
{
|
||||
// if not passed, then generate
|
||||
const auto unrolled_strides = GenerateColumnMajorPackedStrides(unrolled_shape);
|
||||
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
|
||||
"Size of strides and shape are not consistent.");
|
||||
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto unrolled_strides = UnrollNestedTuple(strides);
|
||||
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
|
||||
"Size of strides and shape are not consistent.");
|
||||
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/// @endcond
|
||||
|
||||
// make_*
|
||||
@@ -38,10 +84,10 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
* \return Constructed layout.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
|
||||
const Strides& strides)
|
||||
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
|
||||
{
|
||||
return Layout<Shape, Strides>(shape, strides);
|
||||
using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Strides{}));
|
||||
return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, strides));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -52,9 +98,10 @@ __host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& sh
|
||||
* \return Constructed layout.
|
||||
*/
|
||||
template <typename Shape>
|
||||
__host__ __device__ constexpr Layout<Shape, Tuple<>> make_layout(const Shape& shape)
|
||||
__host__ __device__ constexpr auto make_layout(const Shape& shape)
|
||||
{
|
||||
return Layout<Shape, Tuple<>>(shape);
|
||||
using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Tuple<>{}));
|
||||
return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, Tuple<>{}));
|
||||
}
|
||||
|
||||
// Layout helpers
|
||||
@@ -89,26 +136,51 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
|
||||
* \param layout Layout to create sub layout.
|
||||
* \return Requsted sub layout.
|
||||
*/
|
||||
template <index_t idx, typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
|
||||
template <index_t idx, typename Shape, typename FlattenDesc>
|
||||
__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
|
||||
{
|
||||
const auto& shape = layout.GetShape();
|
||||
const auto& new_shape = get<idx>(shape);
|
||||
const auto& shape = layout.GetShape();
|
||||
const auto new_shape = get<idx>(shape);
|
||||
static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
|
||||
"Shape of sub layout must be tuple");
|
||||
if constexpr(is_same_v<Strides, Tuple<>>)
|
||||
{
|
||||
// If stride not passed, create without strides
|
||||
return make_layout(new_shape);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto& strides = layout.GetStrides();
|
||||
const auto& new_strides = get<idx>(strides);
|
||||
static_assert(is_detected<is_tuple, decltype(new_strides)>::value,
|
||||
"Strides of sub layout must be tuple");
|
||||
return make_layout(new_shape, new_strides);
|
||||
}
|
||||
|
||||
constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
|
||||
constexpr auto new_shape_dims = decltype(UnrollNestedTuple(new_shape))::Size();
|
||||
constexpr auto shape_offset = decltype(UnrollNestedTuple(TupleSlice<0, idx>(shape)))::Size();
|
||||
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
const auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
// Compare Idx with shape
|
||||
if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
|
||||
{
|
||||
// Remove dimension
|
||||
return make_freeze_transform(Number<0>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(unrolled_shape.At(i));
|
||||
}
|
||||
},
|
||||
Number<old_shape_dims>{});
|
||||
|
||||
const auto lower_dims =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
|
||||
const auto upper_dims = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
|
||||
return Sequence<>{};
|
||||
|
||||
else
|
||||
{
|
||||
return Sequence<i.value - shape_offset>{};
|
||||
}
|
||||
},
|
||||
Number<old_shape_dims>{});
|
||||
|
||||
const auto& flatten_desc = layout.GetUnnestedDescriptor();
|
||||
auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
|
||||
return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -142,8 +214,8 @@ __host__ __device__ T constexpr size(const T& dim)
|
||||
* \param layout Layout to get Shape of.
|
||||
* \return Requsted length.
|
||||
*/
|
||||
template <index_t idx, typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
|
||||
template <index_t idx, typename Shape, typename UnnestedDescriptorType>
|
||||
__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
|
||||
{
|
||||
return layout.template GetLength<idx>();
|
||||
}
|
||||
@@ -155,7 +227,7 @@ __host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
|
||||
* \return Requsted size.
|
||||
*/
|
||||
template <typename... ShapeDims>
|
||||
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
|
||||
__host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
|
||||
@@ -168,8 +240,8 @@ __host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
|
||||
* \param layout Layout to calculate shape size.
|
||||
* \return Requsted size.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
|
||||
template <typename Shape, typename UnnestedDescriptorType>
|
||||
__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
|
||||
{
|
||||
return layout.GetLengths();
|
||||
}
|
||||
@@ -182,7 +254,7 @@ __host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
|
||||
* \return Requsted length.
|
||||
*/
|
||||
template <index_t idx, typename... Ts>
|
||||
__host__ __device__ constexpr index_t size(const Tuple<Ts...>& tuple)
|
||||
__host__ __device__ constexpr auto size(const Tuple<Ts...>& tuple)
|
||||
{
|
||||
return size(tuple.At(Number<idx>{}));
|
||||
}
|
||||
@@ -208,8 +280,9 @@ __host__ __device__ constexpr auto size(const T& elem)
|
||||
* \param layout Layout to calculate rank.
|
||||
* \return Requsted rank.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto rank([[maybe_unused]] const Layout<Shape, Strides>& layout)
|
||||
template <typename Shape, typename UnnestedDescriptorType>
|
||||
__host__ __device__ constexpr auto
|
||||
rank([[maybe_unused]] const Layout<Shape, UnnestedDescriptorType>& layout)
|
||||
{
|
||||
return Shape::Size();
|
||||
}
|
||||
@@ -261,8 +334,8 @@ __host__ __device__ constexpr auto rank(const T& elem)
|
||||
* \param layout Layout to calculate depth.
|
||||
* \return Requsted depth.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto depth(const Layout<Shape, Strides>& layout)
|
||||
template <typename Shape, typename UnnestedDescriptorType>
|
||||
__host__ __device__ constexpr auto depth(const Layout<Shape, UnnestedDescriptorType>& layout)
|
||||
{
|
||||
const auto& shape = layout.GetShape();
|
||||
return TupleDepth(shape);
|
||||
@@ -307,26 +380,14 @@ __host__ __device__ constexpr auto depth(const T& elem)
|
||||
return depth(get<Idxs...>(elem));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get Layout strides.
|
||||
*
|
||||
* \param layout Layout to get strides from.
|
||||
* \return Requsted strides.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr const auto& stride(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return layout.GetStrides();
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get Layout shape.
|
||||
*
|
||||
* \param layout Layout to get shape from.
|
||||
* \return Requsted shape.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr const auto& shape(const Layout<Shape, Strides>& layout)
|
||||
template <typename LayoutType>
|
||||
__host__ __device__ constexpr const auto& shape(const LayoutType& layout)
|
||||
{
|
||||
return layout.GetShape();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user