Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-04-20 06:49:15 +00:00)

Add tensor partition and generic copy for ck wrapper (#1108)

* Add tensor partition and generic copy for ck wrapper
* Update changelog
* Stylistic fixes
* Change shape/strides logic to descriptor transforms
* Fixes
* Fix client example
* Fix comments
@@ -14,11 +14,9 @@ namespace wrapper {
 * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
 * (dynamic layout). It is possible to pass nested shapes
 * (e.g. ((4, 2), 2)); nested dimensions are merged.
- * \tparam Strides Tuple of Number<> (for compile-time layout) or index_t
- * (dynamic layout). Stride tuple should be nested if shape tuple is
- * nested.
+ * \tparam UnnestedDescriptorType Tensor descriptor for unnested shape dims.
 */
-template <typename Shape, typename Strides>
+template <typename Shape, typename UnnestedDescriptorType>
struct Layout
{
    private:
@@ -31,7 +29,7 @@ struct Layout
    {
        return generate_tuple(
            [&](auto) {
-                if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
+                if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
                {
                    // runtime layout
                    return index_t(0);
@@ -45,27 +43,6 @@ struct Layout
            Number<Tuple<Ts...>::Size()>{});
    }

-    // Generate packed (column-major) strides if not passed
-    template <typename... Ts>
-    __host__ __device__ constexpr static auto
-    GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
-    {
-        const auto unrolled_shape = UnrollNestedTuple(shape);
-        return generate_tuple(
-            [&](auto i) {
-                if constexpr(i.value == 0)
-                {
-                    return I1;
-                }
-                else
-                {
-                    return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
-                                                          unrolled_shape);
-                }
-            },
-            Number<decltype(unrolled_shape)::Size()>{});
-    }
-
    // Generate LowerDims in compile time for MergeTransform using the passed type:
    // if an element of Tuple<Ts...> is itself a tuple, then merge (generate a sequence for the merge);
    // if the element is a scalar, then pass through (sequence with one element)
@@ -207,33 +184,15 @@ struct Layout
        return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
    }

-    template <typename LayoutShape, typename LayoutStrides>
-    __host__ __device__ static auto MakeFlattenDescriptor(const LayoutShape& shape,
-                                                          const LayoutStrides& strides)
-    {
-        const auto unrolled_shape = UnrollNestedTuple(shape);
-        const auto unrolled_strides = UnrollNestedTuple(strides);
-        static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
-                      "Size of strides and shape are not consistent.");
-        return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
-    }
-
-    // If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
-    using DeducedStrides =
-        std::conditional_t<is_same_v<Strides, Tuple<>>,
-                           remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
-                           Strides>;
-    using FlattenDescriptorType =
-        remove_cvref_t<decltype(MakeFlattenDescriptor(Shape{}, DeducedStrides{}))>;
    using Descriptor1dType =
-        remove_cvref_t<decltype(MakeMerge1d(Shape{}, FlattenDescriptorType{}))>;
+        remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnnestedDescriptorType{}))>;
    using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;

    template <typename... ShapeDims, typename... IdxDims>
    __host__ __device__ constexpr static auto
    TransformDesc(const Tuple<ShapeDims...>& shape,
                  const Tuple<IdxDims...>& idx,
-                  const FlattenDescriptorType& naive_descriptor)
+                  const UnnestedDescriptorType& naive_descriptor)
    {
        if constexpr(Tuple<IdxDims...>::Size() == I1)
        {
@@ -256,48 +215,33 @@ struct Layout
    }

    using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
-        Shape{}, DefaultIdxsTupleType{}, FlattenDescriptorType{}))>;
+        Shape{}, DefaultIdxsTupleType{}, UnnestedDescriptorType{}))>;

    public:
    __host__ __device__ constexpr auto GetElementSpaceSize() const
    {
-        return flatten_descriptor_.GetElementSpaceSize();
+        return unnested_descriptor_.GetElementSpaceSize();
    }

    __host__ __device__ Layout() = delete;

    /**
     * \brief Layout constructor.
     *
     * \param shape Shape for layout.
-     * \param strides Strides for layout (optional if tensor is packed).
+     * \param unnested_descriptor Descriptor for the unnested (flattened) shape dims.
     */
-    __host__ __device__ constexpr Layout(const Shape& shape, const Strides& strides)
-        : flatten_descriptor_{}, shape_(shape), strides_(strides)
+    __host__ __device__ constexpr Layout(const Shape& shape,
+                                         const UnnestedDescriptorType& unnested_descriptor)
+        : shape_(shape)
    {
        // Construct if runtime mode
-        if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
+        if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
        {
-            flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
-            descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
+            unnested_descriptor_ = unnested_descriptor;
+            descriptor_1d_ = MakeMerge1d(shape_, unnested_descriptor_);
            merged_nests_descriptor_ =
-                TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
+                TransformDesc(shape_, DefaultIdxsTupleType{}, unnested_descriptor_);
        }
    }

-    /**
-     * \brief Layout constructor (with default packed column-major strides).
-     *
-     * \param shape Shape for layout.
-     */
-    __host__ __device__ constexpr Layout(const Shape& shape)
-        : flatten_descriptor_{}, shape_(shape), strides_(GenerateColumnMajorPackedStrides(shape_))
-    {
-        if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
-        {
-            flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
-            descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
-            merged_nests_descriptor_ =
-                TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
-        }
-    }
@@ -310,9 +254,9 @@ struct Layout
    template <typename Idxs>
    __host__ __device__ constexpr index_t operator()() const
    {
-        static_assert(FlattenDescriptorType::IsKnownAtCompileTime(),
+        static_assert(UnnestedDescriptorType::IsKnownAtCompileTime(),
                      "Compiletime operator used on runtime layout.");
-        using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, FlattenDescriptorType{}));
+        using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnnestedDescriptorType{}));
        using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
        return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
    }
@@ -339,7 +283,7 @@ struct Layout
        else
        {
            // Custom index, need to transform descriptor
-            const auto transformed_desc = TransformDesc(shape_, Idx, flatten_descriptor_);
+            const auto transformed_desc = TransformDesc(shape_, Idx, unnested_descriptor_);
            return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
        }
    }
@@ -351,7 +295,7 @@ struct Layout
     * \return Calculated size.
     */
    template <index_t IDim>
-    __host__ __device__ constexpr index_t GetLength() const
+    __host__ __device__ constexpr auto GetLength() const
    {
        const auto elem = shape_.At(Number<IDim>{});
        if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
@@ -371,7 +315,7 @@ struct Layout
     *
     * \return Calculated size.
     */
-    __host__ __device__ constexpr index_t GetLengths() const
+    __host__ __device__ constexpr auto GetLengths() const
    {
        const auto unrolled_shape = UnrollNestedTuple(shape_);
        return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
@@ -385,13 +329,6 @@ struct Layout
     */
    __host__ __device__ constexpr const Shape& GetShape() const { return shape_; }

-    /**
-     * \brief Strides getter.
-     *
-     * \return Strides.
-     */
-    __host__ __device__ constexpr const DeducedStrides& GetStrides() const { return strides_; }
-
    /**
     * \brief Get default lengths (tuple filled with Shape length elements).
     *
@@ -417,17 +354,26 @@ struct Layout
     *
     * \return Default descriptor.
     */
-    __host__ __device__ constexpr MergedNestsDescriptorType GetDefaultDescriptor()
+    __host__ __device__ constexpr const MergedNestsDescriptorType& GetDefaultDescriptor() const
    {
        return merged_nests_descriptor_;
    }

+    /**
+     * \brief Get unnested descriptor (with unrolled dims).
+     *
+     * \return Flattened descriptor.
+     */
+    __host__ __device__ constexpr const UnnestedDescriptorType& GetUnnestedDescriptor() const
+    {
+        return unnested_descriptor_;
+    }
+
    private:
-    FlattenDescriptorType flatten_descriptor_;
+    UnnestedDescriptorType unnested_descriptor_;
    Descriptor1dType descriptor_1d_;
    MergedNestsDescriptorType merged_nests_descriptor_;
    const Shape shape_;
-    const DeducedStrides strides_;
};

} // namespace wrapper
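To illustrate the nested-shape semantics described in the docstring above, a minimal sketch using the wrapper helpers introduced in this PR (values illustrative, not part of the commit):

```cpp
// Sketch: runtime layout with nested shape ((4, 2), 2) and internally
// generated packed column-major strides ((1, 4), 8). Nested dimensions
// are merged by the Layout.
const auto shape  = ck::make_tuple(ck::make_tuple(4, 2), 2);
const auto layout = ck::wrapper::make_layout(shape);
// Offset of nested index ((1, 1), 1): 1*1 + 1*4 + 1*8 = 13.
const ck::index_t offset = layout(ck::make_tuple(ck::make_tuple(1, 1), 1));
```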
include/ck/wrapper/operations/copy.hpp (new file, 41 lines)
@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "../utils/tensor_utils.hpp"

namespace ck {
namespace wrapper {

/**
 * \brief Perform generic copy between two tensors. Tensors must have the
 * same size.
 *
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 */
template <typename SrcTensorType, typename DstTensorType>
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
    if constexpr(!SrcTensorType::IsDynamicBuffer)
    {
        using SizeType = decltype(size(src_tensor));
        static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
    }
    else if constexpr(!DstTensorType::IsDynamicBuffer)
    {
        using SizeType = decltype(size(dst_tensor));
        static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
    }
    else
    {
        for(int i = 0; i < size(src_tensor); i++)
        {
            dst_tensor(i) = src_tensor(i);
        }
    }
}

} // namespace wrapper
} // namespace ck
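When either side is a static (register) buffer, the copy is unrolled at compile time with `static_for`; otherwise a runtime loop is used. For context, a minimal usage sketch, assuming device-side code where `thread_partition` is a hypothetical same-sized view created with `make_local_partition`:

```cpp
// Sketch: copy a thread-private register tensor into a global-memory view.
// Sizes must match; the src is a static buffer, so the static_for path is taken.
auto vgpr = ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr,
                                              4 /*NumVectors*/,
                                              1 /*ScalarPerVector*/,
                                              float>();
// ... compute into vgpr ...
ck::wrapper::copy(vgpr, thread_partition); // thread_partition: hypothetical destination view
```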
@@ -1,9 +1,10 @@
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "utils/tensor_utils.hpp"
+#include "utils/tensor_partition.hpp"
#include "utils/layout_utils.hpp"

namespace ck {
@@ -15,14 +16,14 @@ namespace wrapper {
 * \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
 * \tparam ElementType Element data type.
 * \tparam Shape Tensor shape (layout component).
- * \tparam Strides Tensor strides (layout component).
+ * \tparam UnnestedDescriptorType Unnested descriptor (layout component).
 * \tparam NumVectors Number of vectors (only for VGPR, SGPR).
 * \tparam ScalarPerVector Scalars per vector (only for VGPR, SGPR).
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
-          typename Strides,
+          typename UnnestedDescriptorType,
          index_t NumVectors,   // param for Register memory
          index_t ScalarPerVector // param for Register memory
          >
@@ -31,50 +32,20 @@ struct Tensor
    private:
    // Check if Tuple contains Slice object
    template <typename T>
-    constexpr static bool IsSlicing(T&&)
+    __host__ __device__ constexpr static bool IsSlicing(T&&)
    {
        return is_detected<is_slice, T>::value;
    }
    template <typename... Ts>
-    constexpr static bool IsSlicing(Tuple<Ts...>&&)
+    __host__ __device__ constexpr static bool IsSlicing(Tuple<Ts...>&&)
    {
        return (IsSlicing(Ts{}) || ...);
    }

-    // Calculate first index of new tensor after slice
-    // It is needed to calculate offset for new tensor
-    template <typename... Ts>
-    constexpr auto GetStartIdxForSlicedTensor(const Tuple<Ts...>& idx) const
-    {
-        const auto start_idx_for_sliced_tensor = generate_tuple(
-            [&](auto i) {
-                constexpr auto num_i = Number<i>{};
-                if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
-                {
-                    // if tuple then recurse
-                    return GetStartIdxForSlicedTensor(idx.At(num_i));
-                }
-                else if constexpr(is_detected<is_slice,
-                                              tuple_element_t<i.value, Tuple<Ts...>>>::value)
-                {
-                    // if slice, return the beginning of the interval
-                    return idx.At(num_i).from_;
-                }
-                else
-                {
-                    // if one dim selected
-                    return idx.At(num_i);
-                }
-            },
-            Number<Tuple<Ts...>::Size()>{});
-
-        return start_idx_for_sliced_tensor;
-    }
-
    // Calculate new tensor shape after slice
    template <typename... Ts, typename ShapeTmpType>
-    constexpr auto GetShapeFromSlicedTensor(const Tuple<Ts...>& idx,
-                                            const ShapeTmpType& shape) const
+    __host__ __device__ constexpr auto GetShapeFromSlicedTensor(const Tuple<Ts...>& idx,
+                                                                const ShapeTmpType& shape) const
    {
        // Pack each value in tuple to remove empty tuples after generation
        auto new_shape = generate_tuple(
@@ -112,67 +83,137 @@ struct Tensor
        return UnrollNestedTuple<0, 1>(new_shape);
    }

-    template <typename... Ts, typename StridesTmpType>
-    constexpr auto GetStridesFromSlicedTensor(const Tuple<Ts...>& idx,
-                                              const StridesTmpType& strides) const
+    // Generate a Freeze transform for each dim of the nested shape
+    template <typename T, typename ShapeTmpType>
+    __host__ __device__ constexpr auto GenerateMultipleFreeze(T idx,
+                                                              const ShapeTmpType& shape) const
+    {
+        const auto unrolled_shape = UnrollNestedTuple(shape);
+        return generate_tuple(
+            [&](auto i) {
+                // dimension offset from idx
+                const auto dim = unrolled_shape.At(Number<i>{});
+                const auto dim_idx = idx % dim;
+                idx /= dim;
+                return make_freeze_transform(dim_idx);
+            },
+            Number<decltype(unrolled_shape)::Size()>{});
+    }
+
+    template <typename... Ts, typename ShapeTmpType>
+    __host__ __device__ constexpr auto
+    GetTransformsFromSlicedTensor(const Tuple<Ts...>& idx, const ShapeTmpType& shape) const
    {
        // Pack each value in tuple to remove empty tuples after generation
-        auto new_strides = generate_tuple(
+        auto transforms = generate_tuple(
            [&](auto i) {
                constexpr auto num_i = Number<i>{};
                if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
                {
-                    if constexpr(!IsSlicing(tuple_element_t<i.value, Tuple<Ts...>>{}))
-                    {
-                        // if tuple does not have any slice then we can remove dimension
-                        return Tuple<>{};
-                    }
-                    else
-                    {
-                        // if tuple then recurse
-                        return make_tuple(
-                            GetStridesFromSlicedTensor(idx.At(num_i), strides.At(num_i)));
-                    }
+                    // if tuple then recurse
+                    return GetTransformsFromSlicedTensor(idx.At(num_i), shape.At(num_i));
                }
                else if constexpr(is_detected<is_slice,
                                              tuple_element_t<i.value, Tuple<Ts...>>>::value)
                {
-                    // Stride will be the same
-                    return make_tuple(strides.At(num_i));
+                    // if slice then generate a Slice transform over the interval
+                    const auto from = idx.At(num_i).from_;
+                    const auto dim = shape.At(num_i);
+                    const auto range = idx.At(num_i).range(dim);
+                    return make_slice_transform(range, from, from + range);
                }
                else
                {
-                    // remove dimension for just value
-                    return Tuple<>{};
+                    // if a single value is selected then freeze the dimension
+                    return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
                }
            },
            Number<Tuple<Ts...>::Size()>{});
-        // Remove empty tuples (deleted elements) and return
-        return UnrollNestedTuple<0, 1>(new_strides);
+        return UnrollNestedTuple(transforms);
    }

+    // There is no output for Freeze transform
+    template <index_t i, typename LowerIndex>
+    __host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<LowerIndex>&) const
+    {
+        return Sequence<>{};
+    }
+
+    template <index_t i, typename LowLength, typename SliceBegin, typename SliceEnd>
+    __host__ __device__ constexpr auto
+    GetSequenceVal(const ck::Slice<LowLength, SliceBegin, SliceEnd>&) const
+    {
+        return Sequence<i>{};
+    }
+
+    template <index_t i>
+    __host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&) const
+    {
+        return Tuple<>{};
+    }
+
+    template <index_t i, typename... Transforms>
+    __host__ __device__ constexpr auto
+    GenerateUpperDims(const Tuple<Transforms...>& transforms) const
+    {
+        constexpr auto num_transforms = Tuple<Transforms...>::Size();
+        // Deduce Sequence element for specific transform
+        const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
+        if constexpr(is_same_v<decltype(current_elem), const Sequence<>>)
+        {
+            const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
+            return concat_tuple(make_tuple(current_elem), next_tuple);
+        }
+        else
+        {
+            // Increase i if current_elem is Slice transform
+            const auto next_tuple =
+                GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
+            return concat_tuple(make_tuple(current_elem), next_tuple);
+        }
+    }
+
+    template <typename... Ts, typename ShapeTmpType, typename FlattenDescriptor>
+    __host__ __device__ constexpr auto
+    GetDescriptorFromSlicedTensor(const Tuple<Ts...>& idx,
+                                  const ShapeTmpType& shape,
+                                  const FlattenDescriptor& flatten_desc) const
+    {
+        constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
+
+        const auto transforms = GetTransformsFromSlicedTensor(idx, shape);
+        using TransformsTupleType = decltype(transforms);
+
+        const auto lower_dims =
+            generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
+        const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
+        return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
+    }
+
    public:
-    using ElementSpaceSize = decltype(Layout<Shape, Strides>{
-        Shape{}, Strides{}}.GetElementSpaceSize()); // SpaceSize type for buffer
-    using TensorElementType = ElementType;          // DataType
+    using ElementSpaceSize = decltype(Layout<Shape, UnnestedDescriptorType>{
+        Shape{}, UnnestedDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
+    using TensorElementType = ElementType;                         // DataType

    static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
    static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr ||
                                              BufferAddressSpace == MemoryTypeEnum ::Vgpr);

    __host__ __device__ Tensor() = delete;
-    __host__ __device__ Tensor(ElementType* pointer, const Layout<Shape, Strides>& layout)
+    __host__ __device__ Tensor(ElementType* pointer,
+                               const Layout<Shape, UnnestedDescriptorType>& layout)
        : layout_(layout),
          buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize()))
    {
    }

-    __host__ __device__ Tensor(const Layout<Shape, Strides>& layout) : layout_(layout)
+    __host__ __device__ Tensor(const Layout<Shape, UnnestedDescriptorType>& layout)
+        : layout_(layout)
    {
        static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

-    __host__ __device__ constexpr const Layout<Shape, Strides>& GetLayout() const
+    __host__ __device__ constexpr const Layout<Shape, UnnestedDescriptorType>& GetLayout() const
    {
        return layout_;
    }
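`GenerateMultipleFreeze` above splits one scalar index addressed at a nested dimension into per-dimension Freeze transforms, peeling coordinates off in column-major order. The same arithmetic as a standalone sketch (values illustrative):

```cpp
// Standalone illustration of GenerateMultipleFreeze's index arithmetic:
// a linear index into a nested dim of unrolled shape (4, 2, 2) is split
// into per-dimension coordinates, fastest-varying dimension first.
#include <cstdio>

int main()
{
    const int shape[3] = {4, 2, 2};
    int idx = 13; // 13 = 1*1 + 1*4 + 1*8 -> coordinates (1, 1, 1)
    for(int d = 0; d < 3; ++d)
    {
        const int dim_idx = idx % shape[d]; // coordinate frozen for dim d
        idx /= shape[d];
        std::printf("dim %d -> %d\n", d, dim_idx);
    }
    return 0;
}
```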
@@ -182,21 +223,14 @@ struct Tensor
    __host__ __device__ auto operator[](const Tuple<Ts...>& idx) const
    {
        static_assert(IsDynamicBuffer, "Register slice is not supported");
-        // Calculate offset based on first idx for new tensor
-        const index_t offset = layout_(GetStartIdxForSlicedTensor(idx));
-
-        auto new_shape = GetShapeFromSlicedTensor(idx, layout_.GetShape());
-        if constexpr(is_same_v<Strides, Tuple<>>)
-        {
-            auto new_layout = make_layout(new_shape);
-            return make_tensor<BufferAddressSpace>(buffer_.p_data_ + offset, new_layout);
-        }
-        else
-        {
-            auto new_strides = GetStridesFromSlicedTensor(idx, layout_.GetStrides());
-            auto new_layout = make_layout(new_shape, new_strides);
-            return make_tensor<BufferAddressSpace>(buffer_.p_data_ + offset, new_layout);
-        }
+        const auto& shape = layout_.GetShape();
+        auto new_shape = GetShapeFromSlicedTensor(idx, shape);
+        const auto& flatten_desc = layout_.GetUnnestedDescriptor();
+        auto new_desc = GetDescriptorFromSlicedTensor(idx, shape, flatten_desc);
+        const auto new_layout =
+            Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
+        return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
    }

    template <typename... Ts, enable_if_t<IsSlicing(Tuple<Ts...>{}), bool> = false>
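Note that the sliced view now keeps the original base pointer; the offset is baked into the new descriptor via Slice/Freeze transforms instead of being added to the pointer. A hedged usage sketch (the slice-maker helper is assumed here under the name `slice`; `tensor4x4` is a hypothetical 4x4 global tensor):

```cpp
// Sketch: take column 1 of a (4, 4) tensor as a 4-element view.
// slice(0, 4) is assumed to build a Slice object covering rows 0..3;
// the scalar 1 freezes the column dimension.
auto column = tensor4x4[ck::make_tuple(ck::wrapper::slice(0, 4), 1)];
// column is itself a Tensor, e.g. ck::wrapper::copy(column, dst_view);
```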
@@ -222,18 +256,10 @@ struct Tensor
        }
        else
        {
-            if constexpr(is_same_v<Strides, Tuple<>>)
-            {
-                constexpr index_t offset =
-                    Layout<Shape, Strides>{Shape{}}.template operator()<Tuple<Ts...>>();
-                return buffer_[Number<offset>{}];
-            }
-            else
-            {
-                constexpr index_t offset =
-                    Layout<Shape, Strides>{Shape{}, Strides{}}.template operator()<Tuple<Ts...>>();
-                return buffer_[Number<offset>{}];
-            }
+            constexpr index_t offset = Layout<Shape, UnnestedDescriptorType>{
+                Shape{},
+                UnnestedDescriptorType{}}.template operator()<Tuple<Ts...>>();
+            return buffer_[Number<offset>{}];
        }
    }
@@ -260,18 +286,10 @@ struct Tensor
        }
        else
        {
-            if constexpr(is_same_v<Strides, Tuple<>>)
-            {
-                constexpr index_t offset =
-                    Layout<Shape, Strides>{Shape{}}.template operator()<Tuple<Ts...>>();
-                return buffer_(Number<offset>{});
-            }
-            else
-            {
-                constexpr index_t offset =
-                    Layout<Shape, Strides>{Shape{}, Strides{}}.template operator()<Tuple<Ts...>>();
-                return buffer_(Number<offset>{});
-            }
+            constexpr index_t offset = Layout<Shape, UnnestedDescriptorType>{
+                Shape{},
+                UnnestedDescriptorType{}}.template operator()<Tuple<Ts...>>();
+            return buffer_(Number<offset>{});
        }
    }
@@ -292,6 +310,8 @@ struct Tensor
        return layout_.GetDefaultDescriptor();
    }

+    __host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; }
+
    private:
    using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
                                            ElementType,
@@ -306,7 +326,7 @@ struct Tensor
    // If register use static buffer, else use dynamic buffer
    using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;

-    const Layout<Shape, Strides> layout_;
+    const Layout<Shape, UnnestedDescriptorType> layout_;
    Buffer buffer_;
};
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -22,11 +22,57 @@ namespace wrapper {
// Disable from doxygen docs generation
/// @cond
// forward declaration
-template <typename Shape, typename Strides>
+template <typename Shape, typename UnnestedDescriptorType>
struct Layout;

template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());

+namespace {
+// Generate packed (column-major) strides if not passed
+template <typename... Ts>
+__host__ __device__ constexpr static auto
+GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
+{
+    const auto unrolled_shape = UnrollNestedTuple(shape);
+    return generate_tuple(
+        [&](auto i) {
+            if constexpr(i.value == 0)
+            {
+                return Number<1>{};
+            }
+            else
+            {
+                return TupleReduce<Number<0>{}.value, i.value>([](auto x, auto y) { return x * y; },
+                                                               unrolled_shape);
+            }
+        },
+        Number<decltype(unrolled_shape)::Size()>{});
+}
+
+template <typename LayoutShape, typename LayoutStrides>
+__host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shape,
+                                                         const LayoutStrides& strides)
+{
+    const auto unrolled_shape = UnrollNestedTuple(shape);
+    if constexpr(is_same_v<LayoutStrides, Tuple<>>)
+    {
+        // if not passed, then generate
+        const auto unrolled_strides = GenerateColumnMajorPackedStrides(unrolled_shape);
+        static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
+                      "Size of strides and shape are not consistent.");
+        return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
+    }
+    else
+    {
+        const auto unrolled_strides = UnrollNestedTuple(strides);
+        static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
+                      "Size of strides and shape are not consistent.");
+        return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
+    }
+}
+} // namespace
+
/// @endcond

// make_*
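The stride generation above follows the packed column-major rule: stride[i] is the product of dimensions 0..i-1. A standalone check of that arithmetic:

```cpp
// Worked example (host-side) of the rule implemented by
// GenerateColumnMajorPackedStrides: shape (4, 2, 2) -> strides (1, 4, 8).
#include <cassert>

int main()
{
    const int shape[3] = {4, 2, 2};
    int strides[3];
    strides[0] = 1;
    for(int i = 1; i < 3; ++i)
        strides[i] = strides[i - 1] * shape[i - 1]; // exclusive product of earlier dims
    assert(strides[0] == 1 && strides[1] == 4 && strides[2] == 8);
    return 0;
}
```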
@@ -38,10 +84,10 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
 * \return Constructed layout.
 */
template <typename Shape, typename Strides>
-__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
-                                                                 const Strides& strides)
+__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
{
-    return Layout<Shape, Strides>(shape, strides);
+    using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Strides{}));
+    return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, strides));
}

/**
@@ -52,9 +98,10 @@ __host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& sh
 * \return Constructed layout.
 */
template <typename Shape>
-__host__ __device__ constexpr Layout<Shape, Tuple<>> make_layout(const Shape& shape)
+__host__ __device__ constexpr auto make_layout(const Shape& shape)
{
-    return Layout<Shape, Tuple<>>(shape);
+    using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Tuple<>{}));
+    return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, Tuple<>{}));
}

// Layout helpers
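A usage sketch of the two overloads after this change (shape and strides illustrative); both now build the unnested descriptor eagerly via `MakeFlattenDescriptor`:

```cpp
const auto shape = ck::make_tuple(ck::make_tuple(2, 2), 8);
// explicit (nested) strides:
const auto l0 = ck::wrapper::make_layout(shape, ck::make_tuple(ck::make_tuple(1, 2), 4));
// packed column-major strides generated internally:
const auto l1 = ck::wrapper::make_layout(shape);
```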
@@ -89,26 +136,51 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
 * \param layout Layout to create sub layout.
 * \return Requested sub layout.
 */
-template <index_t idx, typename Shape, typename Strides>
-__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
+template <index_t idx, typename Shape, typename FlattenDesc>
+__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
{
    const auto& shape = layout.GetShape();
-    const auto& new_shape = get<idx>(shape);
+    const auto new_shape = get<idx>(shape);
    static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
                  "Shape of sub layout must be tuple");
-    if constexpr(is_same_v<Strides, Tuple<>>)
-    {
-        // If stride not passed, create without strides
-        return make_layout(new_shape);
-    }
-    else
-    {
-        const auto& strides = layout.GetStrides();
-        const auto& new_strides = get<idx>(strides);
-        static_assert(is_detected<is_tuple, decltype(new_strides)>::value,
-                      "Strides of sub layout must be tuple");
-        return make_layout(new_shape, new_strides);
-    }
+
+    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
+    constexpr auto new_shape_dims = decltype(UnrollNestedTuple(new_shape))::Size();
+    constexpr auto shape_offset = decltype(UnrollNestedTuple(TupleSlice<0, idx>(shape)))::Size();
+
+    const auto unrolled_shape = UnrollNestedTuple(shape);
+    const auto transforms = generate_tuple(
+        [&](auto i) {
+            // Compare Idx with shape
+            if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
+            {
+                // Remove dimension
+                return make_freeze_transform(Number<0>{});
+            }
+            else
+            {
+                return make_pass_through_transform(unrolled_shape.At(i));
+            }
+        },
+        Number<old_shape_dims>{});
+
+    const auto lower_dims =
+        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
+    const auto upper_dims = generate_tuple(
+        [&](auto i) {
+            if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
+            {
+                return Sequence<>{};
+            }
+            else
+            {
+                return Sequence<i.value - shape_offset>{};
+            }
+        },
+        Number<old_shape_dims>{});
+
+    const auto& flatten_desc = layout.GetUnnestedDescriptor();
+    auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
+    return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
}

/**
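A sketch of sub-layout extraction with the new descriptor-based `get` (shape illustrative):

```cpp
// get<0> extracts the nested (2, 2) part of a ((2, 2), 8) layout as its own
// sub-layout; the remaining dimension is removed with a Freeze transform
// pinned at index 0, as in the implementation above.
const auto layout     = ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(2, 2), 8));
const auto sub_layout = ck::wrapper::get<0>(layout); // shape (2, 2)
```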
@@ -142,8 +214,8 @@ __host__ __device__ T constexpr size(const T& dim)
 * \param layout Layout to get Shape of.
 * \return Requested length.
 */
-template <index_t idx, typename Shape, typename Strides>
-__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
+template <index_t idx, typename Shape, typename UnnestedDescriptorType>
+__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
{
    return layout.template GetLength<idx>();
}
@@ -155,7 +227,7 @@ __host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
 * \return Requested size.
 */
template <typename... ShapeDims>
-__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
+__host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
@@ -168,8 +240,8 @@ __host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
 * \param layout Layout to calculate shape size.
 * \return Requested size.
 */
-template <typename Shape, typename Strides>
-__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
+template <typename Shape, typename UnnestedDescriptorType>
+__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
{
    return layout.GetLengths();
}
@@ -182,7 +254,7 @@ __host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
 * \return Requested length.
 */
template <index_t idx, typename... Ts>
-__host__ __device__ constexpr index_t size(const Tuple<Ts...>& tuple)
+__host__ __device__ constexpr auto size(const Tuple<Ts...>& tuple)
{
    return size(tuple.At(Number<idx>{}));
}
@@ -208,8 +280,9 @@ __host__ __device__ constexpr auto size(const T& elem)
 * \param layout Layout to calculate rank.
 * \return Requested rank.
 */
-template <typename Shape, typename Strides>
-__host__ __device__ constexpr auto rank([[maybe_unused]] const Layout<Shape, Strides>& layout)
+template <typename Shape, typename UnnestedDescriptorType>
+__host__ __device__ constexpr auto
+rank([[maybe_unused]] const Layout<Shape, UnnestedDescriptorType>& layout)
{
    return Shape::Size();
}
@@ -261,8 +334,8 @@ __host__ __device__ constexpr auto rank(const T& elem)
 * \param layout Layout to calculate depth.
 * \return Requested depth.
 */
-template <typename Shape, typename Strides>
-__host__ __device__ constexpr auto depth(const Layout<Shape, Strides>& layout)
+template <typename Shape, typename UnnestedDescriptorType>
+__host__ __device__ constexpr auto depth(const Layout<Shape, UnnestedDescriptorType>& layout)
{
    const auto& shape = layout.GetShape();
    return TupleDepth(shape);
@@ -307,26 +380,14 @@ __host__ __device__ constexpr auto depth(const T& elem)
    return depth(get<Idxs...>(elem));
}

-/**
- * \brief Get Layout strides.
- *
- * \param layout Layout to get strides from.
- * \return Requested strides.
- */
-template <typename Shape, typename Strides>
-__host__ __device__ constexpr const auto& stride(const Layout<Shape, Strides>& layout)
-{
-    return layout.GetStrides();
-}
-
/**
 * \brief Get Layout shape.
 *
 * \param layout Layout to get shape from.
 * \return Requested shape.
 */
-template <typename Shape, typename Strides>
-__host__ __device__ constexpr const auto& shape(const Layout<Shape, Strides>& layout)
+template <typename LayoutType>
+__host__ __device__ constexpr const auto& shape(const LayoutType& layout)
{
    return layout.GetShape();
}
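How the helpers above behave on a nested layout (values illustrative; `size<0>` returns the merged length of the nested dimension):

```cpp
const auto layout = ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(4, 2), 2));
const auto n  = ck::wrapper::size(layout);     // 16: product of all dims
const auto n0 = ck::wrapper::size<0>(layout);  // 8: merged (4, 2) dim
const auto r  = ck::wrapper::rank(layout);     // 2: top-level dims
const auto d  = ck::wrapper::depth(layout);    // 2: nesting level
```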
include/ck/wrapper/utils/tensor_partition.hpp (new file, 285 lines)
@@ -0,0 +1,285 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "tensor_utils.hpp"
#include "layout_utils.hpp"

namespace ck {
namespace wrapper {

namespace {
// Calculate shape for partition based on number of threads per each dim and
// previous shape
template <typename... Ts, typename... Ls>
__host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts...>& shape,
                                                                const Tuple<Ls...>& thread_lengths)
{
    static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
    return generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                // if tuple then recurse
                return CalculateLocalPartitionShape(shape.At(num_i), thread_lengths.At(num_i));
            }
            else
            {
                const auto slice_len = shape.At(num_i) / thread_lengths.At(num_i);
                return slice_len;
            }
        },
        Number<Tuple<Ts...>::Size()>{});
}

// Calculate descriptor for partition based on number of threads per each dim,
// previous strides and steps
template <typename... Ts, typename... Ls, typename... Steps, typename FlattenDescType>
__host__ __device__ constexpr auto
CalculateLocalPartitionDescriptor(const Tuple<Ts...>& shape,
                                  const Tuple<Ls...>& thread_lengths,
                                  const Tuple<Steps...>& steps,
                                  const FlattenDescType& flatten_desc)
{
    static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
    const auto unrolled_thread_lengths = UnrollNestedTuple(thread_lengths);
    const auto unrolled_shape = UnrollNestedTuple(shape);
    constexpr auto dims = decltype(unrolled_thread_lengths)::Size();

    using UnrolledStepsType = decltype(UnrollNestedTuple(steps));

    using I1 = Number<1>;

    const auto transforms = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
            {
                // By default raked partition
                const auto partition_stride = unrolled_thread_lengths.At(num_i);
                return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
                                            make_tuple(partition_stride));
            }
            else if constexpr(!is_same_v<tuple_element_t<i.value, UnrolledStepsType>, index_t>)
            {
                // Compile-time partition
                if constexpr(is_same_v<tuple_element_t<i.value, UnrolledStepsType>, I1>)
                {
                    // raked
                    const auto partition_stride = unrolled_thread_lengths.At(num_i);
                    return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
                                                make_tuple(partition_stride));
                }
                else
                {
                    // packed
                    return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
                                                make_tuple(I1{}));
                }
            }
            else
            {
                // Runtime partition
                if(steps.At(num_i) == 1)
                {
                    // raked
                    const auto partition_stride = unrolled_thread_lengths.At(num_i);
                    return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
                                                make_tuple(partition_stride));
                }
                else
                {
                    // packed
                    return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
                                                make_tuple(I1{}));
                }
            }
        },
        Number<dims>{});

    const auto lower_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
    const auto upper_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
    return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}

template <typename... Ls, typename... Steps>
__host__ __device__ constexpr auto CalculateLayoutOffsetIdxImpl(const Tuple<Ls...>& thread_lengths,
                                                                const Tuple<Steps...>& steps,
                                                                index_t& thread_id)
{
    return generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ls...>>>::value)
            {
                // if tuple then recurse
                if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
                {
                    return CalculateLayoutOffsetIdxImpl(
                        thread_lengths.At(num_i), Tuple<>{}, thread_id);
                }
                else
                {
                    return CalculateLayoutOffsetIdxImpl(
                        thread_lengths.At(num_i), steps.At(num_i), thread_id);
                }
            }
            else
            {
                // Update thread_id after each dim
                const auto dim_thread_id = thread_id % thread_lengths.At(num_i);
                thread_id /= thread_lengths.At(num_i);
                if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
                {
                    return dim_thread_id;
                }
                else
                {
                    // Apply step
                    return steps.At(num_i) * dim_thread_id;
                }
            }
        },
        Number<Tuple<Ls...>::Size()>{});
}

// Convert integer thread_idx to tuple index with steps applied
template <typename... Ls, typename... Steps>
__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple<Ls...>& thread_lengths,
                                                            const Tuple<Steps...>& steps,
                                                            const index_t thread_id)
{
    // Create tmp thread_id copy for CalculateLayoutOffsetIdxImpl updates
    index_t thread_id_copy = thread_id;
    return CalculateLayoutOffsetIdxImpl(thread_lengths, steps, thread_id_copy);
}

// Apply steps to index represented as tuple
template <typename... Steps, typename... Idxs>
__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple<Steps...>& steps,
                                                            const Tuple<Idxs...>& block_idxs)
{
    return generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Idxs...>>>::value)
            {
                // if tuple then recurse
                if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
                {
                    return CalculateLayoutOffsetIdx(Tuple<>{}, block_idxs.At(num_i));
                }
                else
                {
                    return CalculateLayoutOffsetIdx(steps.At(num_i), block_idxs.At(num_i));
                }
            }
            else
            {
                if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
                {
                    return block_idxs.At(num_i);
                }
                else
                {
                    // apply step
                    return steps.At(num_i) * block_idxs.At(num_i);
                }
            }
        },
        Number<Tuple<Idxs...>::Size()>{});
}

// User passes only shape per block to the make_local_tile function. This function calculates
// block layout based on the shape.
template <typename... Ts, typename... BlockDims>
__host__ __device__ constexpr auto CalculateBlockLengths(const Tuple<Ts...>& shape,
                                                         const Tuple<BlockDims...>& tile_shape)
{
    return generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                // if tuple then recurse
                return CalculateBlockLengths(shape.At(num_i), tile_shape.At(num_i));
            }
            else
            {
                return shape.At(num_i) / tile_shape.At(num_i);
            }
        },
        Number<Tuple<Ts...>::Size()>{});
}
} // namespace

/**
 * \brief Create local partition for thread.
 *
 * \param tensor Tensor for partition.
 * \param thread_lengths Layout of threads.
 * \param thread_id Thread index represented as integer.
 * \param steps Thread step (default=1, raked partition).
 * \return Partition tensor.
 */
template <typename TensorType, typename ThreadLengthsTuple, typename StepsTuple = Tuple<>>
__host__ __device__ constexpr auto make_local_partition(const TensorType& tensor,
                                                        const ThreadLengthsTuple& thread_lengths,
                                                        const index_t thread_id,
                                                        const StepsTuple steps = StepsTuple{})
{
    // Create shape, strides and layout for new partition tensor
    const auto partition_shape = CalculateLocalPartitionShape(shape(tensor), thread_lengths);
    // Create new descriptor and layout
    const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor();
    auto partition_desc =
        CalculateLocalPartitionDescriptor(shape(tensor), thread_lengths, steps, flatten_desc);
    const auto partition_layout = Layout<decltype(partition_shape), decltype(partition_desc)>(
        partition_shape, partition_desc);
    // Calculate offset for new partition tensor
    const auto offset_idx = CalculateLayoutOffsetIdx(thread_lengths, steps, thread_id);
    const auto partition_offset = layout(tensor)(offset_idx);
    return make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer() + partition_offset,
                                                             partition_layout);
}

/**
 * \brief Create local tile for thread block.
 *
 * \param tensor Tensor for partition.
 * \param tile_shape Shape of requested tile.
 * \param block_idx Block index represented as tuple.
 * \param steps Block step (default=1, raked partition).
 * \return Tile tensor.
 */
template <typename TensorType,
          typename BlockShapeTuple,
          typename BlockIdxTuple,
          typename StepsTuple = Tuple<>>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
                                                   const BlockShapeTuple& tile_shape,
                                                   const BlockIdxTuple& block_idx,
                                                   const StepsTuple steps = StepsTuple{})
{
    // Create block lengths, strides and layout for new tile tensor
    const auto block_lengths = CalculateBlockLengths(shape(tensor), tile_shape);
    // Create new descriptor and layout
    const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor();
    auto tile_desc =
        CalculateLocalPartitionDescriptor(tile_shape, block_lengths, steps, flatten_desc);
    const auto tile_layout = Layout<remove_reference_t<decltype(tile_shape)>, decltype(tile_desc)>(
        tile_shape, tile_desc);
    // Calculate offset for new partition tensor
    const auto offset_idx = CalculateLayoutOffsetIdx(steps, block_idx);
    const auto tile_offset = layout(tensor)(offset_idx);
    return make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer() + tile_offset,
                                                             tile_layout);
}

} // namespace wrapper
} // namespace ck
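An end-to-end sketch of the partitioning flow these helpers enable, assuming a device-side context and a `global_tensor` built with `make_tensor` (all values illustrative):

```cpp
// 1) carve a per-block 64x64 tile out of the global tensor,
// 2) carve a per-thread partition out of the tile.
const auto tile_shape = ck::make_tuple(ck::Number<64>{}, ck::Number<64>{});
const auto block_idx  = ck::make_tuple(static_cast<ck::index_t>(blockIdx.x),
                                       static_cast<ck::index_t>(blockIdx.y));
auto tile = ck::wrapper::make_local_tile(global_tensor, tile_shape, block_idx);

// 8x8 thread grid: each thread owns a (64/8) x (64/8) = 8x8 raked partition.
const auto thread_lengths = ck::make_tuple(ck::Number<8>{}, ck::Number<8>{});
auto thread_part = ck::wrapper::make_local_partition(
    tile, thread_lengths, static_cast<ck::index_t>(threadIdx.x));
```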
@@ -27,12 +27,12 @@ using MemoryTypeEnum = AddressSpaceEnum;
// Disable from doxygen docs generation
/// @cond
// forward declarations
-template <typename Shape, typename Strides>
+template <typename Shape, typename UnnestedDescriptorType>
struct Layout;
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
-          typename Strides,
+          typename UnnestedDescriptorType,
          index_t NumVectors,   // params for Register memory
          index_t ScalarPerVector // param for Register memory
          >
@@ -98,11 +98,19 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
 * \param layout Tensor layout.
 * \return Constructed tensor.
 */
-template <MemoryTypeEnum MemoryType, typename ElementType, typename Shape, typename Strides>
-constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& layout)
+template <MemoryTypeEnum MemoryType,
+          typename ElementType,
+          typename Shape,
+          typename UnnestedDescriptorType>
+constexpr auto make_tensor(ElementType* pointer,
+                           const Layout<Shape, UnnestedDescriptorType>& layout)
{
-    return Tensor<MemoryType, ElementType, Shape, Strides, 0 /*NumVectors*/, 0 /*ScalarPerVector*/>(
-        pointer, layout);
+    return Tensor<MemoryType,
+                  ElementType,
+                  Shape,
+                  UnnestedDescriptorType,
+                  0 /*NumVectors*/,
+                  0 /*ScalarPerVector*/>(pointer, layout);
}

/**
@@ -112,19 +120,21 @@ constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& l
 * \tparam NumVectors Number of vectors.
 * \tparam ScalarPerVector Scalars per vector.
 * \tparam ElementType Memory data type.
- * \param layout Tensor layout.
 * \return Constructed tensor.
 */
template <MemoryTypeEnum MemoryType,
          index_t NumVectors,
          index_t ScalarPerVector,
-          typename ElementType,
-          typename Shape,
-          typename Strides>
-constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
+          typename ElementType>
+constexpr auto make_register_tensor()
{
-    static_assert(!IsNestedTuple(Shape{}), "Register tensor with nested layout is not supported");
-    return Tensor<MemoryType, ElementType, Shape, Strides, NumVectors, ScalarPerVector>(layout);
+    const auto layout = make_layout(make_tuple(Number<NumVectors>{}), make_tuple(Number<1>{}));
+    return Tensor<MemoryType,
+                  ElementType,
+                  Tuple<Number<NumVectors>>,
+                  std::remove_const_t<remove_reference_t<decltype(layout.GetUnnestedDescriptor())>>,
+                  NumVectors,
+                  ScalarPerVector>(layout);
}

/**
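The register-tensor factory no longer takes a layout; it builds a packed 1-D `Tuple<Number<NumVectors>>` layout internally, so only the memory type, vector parameters and element type are supplied. A usage sketch (parameters illustrative):

```cpp
auto sgpr = ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Sgpr,
                                              2 /*NumVectors*/,
                                              1 /*ScalarPerVector*/,
                                              ck::index_t>();
```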
@@ -136,12 +146,15 @@ constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
-          typename Strides,
+          typename UnnestedDescriptorType,
          index_t NumVectors,
          index_t ScalarPerVector>
-__host__ __device__ constexpr const auto&
-layout(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
-           tensor)
+__host__ __device__ constexpr const auto& layout(const Tensor<BufferAddressSpace,
+                                                              ElementType,
+                                                              Shape,
+                                                              UnnestedDescriptorType,
+                                                              NumVectors,
+                                                              ScalarPerVector>& tensor)
{
    return tensor.GetLayout();
}
@@ -157,12 +170,15 @@ template <index_t... Idxs,
          MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
-          typename Strides,
+          typename UnnestedDescriptorType,
          index_t NumVectors,
          index_t ScalarPerVector>
-__host__ __device__ constexpr index_t
-size(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
-         tensor)
+__host__ __device__ constexpr auto size(const Tensor<BufferAddressSpace,
+                                                     ElementType,
+                                                     Shape,
+                                                     UnnestedDescriptorType,
+                                                     NumVectors,
+                                                     ScalarPerVector>& tensor)
{
    return size<Idxs...>(tensor.GetLayout());
}
@@ -178,12 +194,15 @@ template <index_t... Idxs,
          MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
-          typename Strides,
+          typename UnnestedDescriptorType,
          index_t NumVectors,
          index_t ScalarPerVector>
-__host__ __device__ constexpr index_t
-rank(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
-         tensor)
+__host__ __device__ constexpr auto rank(const Tensor<BufferAddressSpace,
+                                                     ElementType,
+                                                     Shape,
+                                                     UnnestedDescriptorType,
+                                                     NumVectors,
+                                                     ScalarPerVector>& tensor)
{
    return rank<Idxs...>(tensor.GetLayout());
}
@@ -199,35 +218,19 @@ template <index_t... Idxs,
          MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
-          typename Strides,
+          typename UnnestedDescriptorType,
          index_t NumVectors,
          index_t ScalarPerVector>
-__host__ __device__ constexpr index_t
-depth(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
-          tensor)
+__host__ __device__ constexpr auto depth(const Tensor<BufferAddressSpace,
+                                                      ElementType,
+                                                      Shape,
+                                                      UnnestedDescriptorType,
+                                                      NumVectors,
+                                                      ScalarPerVector>& tensor)
{
    return depth<Idxs...>(tensor.GetLayout());
}

-/**
- * \brief Get Tensor strides.
- *
- * \param tensor Tensor to get strides from.
- * \return Requested strides.
- */
-template <MemoryTypeEnum BufferAddressSpace,
-          typename ElementType,
-          typename Shape,
-          typename Strides,
-          index_t NumVectors,
-          index_t ScalarPerVector>
-__host__ __device__ constexpr const auto&
-stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
-           tensor)
-{
-    return stride(tensor.GetLayout());
-}
-
/**
 * \brief Get Tensor shape.
 *
@@ -237,12 +240,15 @@ stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors,
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
-          typename Strides,
+          typename UnnestedDescriptorType,
          index_t NumVectors,
          index_t ScalarPerVector>
-__host__ __device__ constexpr const auto&
-shape(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
-          tensor)
+__host__ __device__ constexpr const auto& shape(const Tensor<BufferAddressSpace,
+                                                             ElementType,
+                                                             Shape,
+                                                             UnnestedDescriptorType,
+                                                             NumVectors,
+                                                             ScalarPerVector>& tensor)
{
    return shape(tensor.GetLayout());
}
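The tensor-level helpers forward to the underlying layout, so they compose with the layout helpers shown earlier. A short sketch (`tensor` is any wrapper tensor):

```cpp
const auto& l = ck::wrapper::layout(tensor); // the tensor's Layout
const auto  s = ck::wrapper::shape(tensor);  // == l.GetShape()
const auto  n = ck::wrapper::size(tensor);   // product of all dims
```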