mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Add tensor partition and generic copy for ck wrapper (#1108)
* Add tensor partition and generic copy for ck wrapper * Update changelog * Stylistic fixes * Change shape/strides logic to descriptor transforms * Fixes * Fix client example * Fix comments
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -22,11 +22,57 @@ namespace wrapper {
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond
|
||||
// forward declaration
|
||||
template <typename Shape, typename Strides>
|
||||
template <typename Shape, typename UnnestedDescriptorType>
|
||||
struct Layout;
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
namespace {
|
||||
// Generate packed (column-major) strides if not passed
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr static auto
|
||||
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i.value == 0)
|
||||
{
|
||||
return Number<1>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return TupleReduce<Number<0>{}.value, i.value>([](auto x, auto y) { return x * y; },
|
||||
unrolled_shape);
|
||||
}
|
||||
},
|
||||
Number<decltype(unrolled_shape)::Size()>{});
|
||||
}
|
||||
|
||||
template <typename LayoutShape, typename LayoutStrides>
|
||||
__host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shape,
|
||||
const LayoutStrides& strides)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
if constexpr(is_same_v<LayoutStrides, Tuple<>>)
|
||||
{
|
||||
// if not passed, then generate
|
||||
const auto unrolled_strides = GenerateColumnMajorPackedStrides(unrolled_shape);
|
||||
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
|
||||
"Size of strides and shape are not consistent.");
|
||||
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto unrolled_strides = UnrollNestedTuple(strides);
|
||||
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
|
||||
"Size of strides and shape are not consistent.");
|
||||
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/// @endcond
|
||||
|
||||
// make_*
|
||||
@@ -38,10 +84,10 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
* \return Constructed layout.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
|
||||
const Strides& strides)
|
||||
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
|
||||
{
|
||||
return Layout<Shape, Strides>(shape, strides);
|
||||
using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Strides{}));
|
||||
return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, strides));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -52,9 +98,10 @@ __host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& sh
|
||||
* \return Constructed layout.
|
||||
*/
|
||||
template <typename Shape>
|
||||
__host__ __device__ constexpr Layout<Shape, Tuple<>> make_layout(const Shape& shape)
|
||||
__host__ __device__ constexpr auto make_layout(const Shape& shape)
|
||||
{
|
||||
return Layout<Shape, Tuple<>>(shape);
|
||||
using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Tuple<>{}));
|
||||
return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, Tuple<>{}));
|
||||
}
|
||||
|
||||
// Layout helpers
|
||||
@@ -89,26 +136,51 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
|
||||
* \param layout Layout to create sub layout.
|
||||
* \return Requsted sub layout.
|
||||
*/
|
||||
template <index_t idx, typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
|
||||
template <index_t idx, typename Shape, typename FlattenDesc>
|
||||
__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
|
||||
{
|
||||
const auto& shape = layout.GetShape();
|
||||
const auto& new_shape = get<idx>(shape);
|
||||
const auto& shape = layout.GetShape();
|
||||
const auto new_shape = get<idx>(shape);
|
||||
static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
|
||||
"Shape of sub layout must be tuple");
|
||||
if constexpr(is_same_v<Strides, Tuple<>>)
|
||||
{
|
||||
// If stride not passed, create without strides
|
||||
return make_layout(new_shape);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto& strides = layout.GetStrides();
|
||||
const auto& new_strides = get<idx>(strides);
|
||||
static_assert(is_detected<is_tuple, decltype(new_strides)>::value,
|
||||
"Strides of sub layout must be tuple");
|
||||
return make_layout(new_shape, new_strides);
|
||||
}
|
||||
|
||||
constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
|
||||
constexpr auto new_shape_dims = decltype(UnrollNestedTuple(new_shape))::Size();
|
||||
constexpr auto shape_offset = decltype(UnrollNestedTuple(TupleSlice<0, idx>(shape)))::Size();
|
||||
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
const auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
// Compare Idx with shape
|
||||
if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
|
||||
{
|
||||
// Remove dimension
|
||||
return make_freeze_transform(Number<0>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(unrolled_shape.At(i));
|
||||
}
|
||||
},
|
||||
Number<old_shape_dims>{});
|
||||
|
||||
const auto lower_dims =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
|
||||
const auto upper_dims = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
|
||||
return Sequence<>{};
|
||||
|
||||
else
|
||||
{
|
||||
return Sequence<i.value - shape_offset>{};
|
||||
}
|
||||
},
|
||||
Number<old_shape_dims>{});
|
||||
|
||||
const auto& flatten_desc = layout.GetUnnestedDescriptor();
|
||||
auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
|
||||
return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -142,8 +214,8 @@ __host__ __device__ T constexpr size(const T& dim)
|
||||
* \param layout Layout to get Shape of.
|
||||
* \return Requsted length.
|
||||
*/
|
||||
template <index_t idx, typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
|
||||
template <index_t idx, typename Shape, typename UnnestedDescriptorType>
|
||||
__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
|
||||
{
|
||||
return layout.template GetLength<idx>();
|
||||
}
|
||||
@@ -155,7 +227,7 @@ __host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
|
||||
* \return Requsted size.
|
||||
*/
|
||||
template <typename... ShapeDims>
|
||||
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
|
||||
__host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
|
||||
@@ -168,8 +240,8 @@ __host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
|
||||
* \param layout Layout to calculate shape size.
|
||||
* \return Requsted size.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
|
||||
template <typename Shape, typename UnnestedDescriptorType>
|
||||
__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
|
||||
{
|
||||
return layout.GetLengths();
|
||||
}
|
||||
@@ -182,7 +254,7 @@ __host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
|
||||
* \return Requsted length.
|
||||
*/
|
||||
template <index_t idx, typename... Ts>
|
||||
__host__ __device__ constexpr index_t size(const Tuple<Ts...>& tuple)
|
||||
__host__ __device__ constexpr auto size(const Tuple<Ts...>& tuple)
|
||||
{
|
||||
return size(tuple.At(Number<idx>{}));
|
||||
}
|
||||
@@ -208,8 +280,9 @@ __host__ __device__ constexpr auto size(const T& elem)
|
||||
* \param layout Layout to calculate rank.
|
||||
* \return Requsted rank.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto rank([[maybe_unused]] const Layout<Shape, Strides>& layout)
|
||||
template <typename Shape, typename UnnestedDescriptorType>
|
||||
__host__ __device__ constexpr auto
|
||||
rank([[maybe_unused]] const Layout<Shape, UnnestedDescriptorType>& layout)
|
||||
{
|
||||
return Shape::Size();
|
||||
}
|
||||
@@ -261,8 +334,8 @@ __host__ __device__ constexpr auto rank(const T& elem)
|
||||
* \param layout Layout to calculate depth.
|
||||
* \return Requsted depth.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto depth(const Layout<Shape, Strides>& layout)
|
||||
template <typename Shape, typename UnnestedDescriptorType>
|
||||
__host__ __device__ constexpr auto depth(const Layout<Shape, UnnestedDescriptorType>& layout)
|
||||
{
|
||||
const auto& shape = layout.GetShape();
|
||||
return TupleDepth(shape);
|
||||
@@ -307,26 +380,14 @@ __host__ __device__ constexpr auto depth(const T& elem)
|
||||
return depth(get<Idxs...>(elem));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get Layout strides.
|
||||
*
|
||||
* \param layout Layout to get strides from.
|
||||
* \return Requsted strides.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr const auto& stride(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return layout.GetStrides();
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get Layout shape.
|
||||
*
|
||||
* \param layout Layout to get shape from.
|
||||
* \return Requsted shape.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr const auto& shape(const Layout<Shape, Strides>& layout)
|
||||
template <typename LayoutType>
|
||||
__host__ __device__ constexpr const auto& shape(const LayoutType& layout)
|
||||
{
|
||||
return layout.GetShape();
|
||||
}
|
||||
|
||||
285
include/ck/wrapper/utils/tensor_partition.hpp
Normal file
285
include/ck/wrapper/utils/tensor_partition.hpp
Normal file
@@ -0,0 +1,285 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensor_utils.hpp"
|
||||
#include "layout_utils.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
|
||||
namespace {
|
||||
// Calculate shape for partition based on number of threads per each dim and
|
||||
// previous shape
|
||||
template <typename... Ts, typename... Ls>
|
||||
__host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts...>& shape,
|
||||
const Tuple<Ls...>& thread_lengths)
|
||||
{
|
||||
static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto num_i = Number<i>{};
|
||||
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
|
||||
{
|
||||
// if tuple then recurrence
|
||||
return CalculateLocalPartitionShape(shape.At(num_i), thread_lengths.At(num_i));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto slice_len = shape.At(num_i) / thread_lengths.At(num_i);
|
||||
return slice_len;
|
||||
}
|
||||
},
|
||||
Number<Tuple<Ts...>::Size()>{});
|
||||
}
|
||||
|
||||
// Calculate shape for partition based on number of threads per each dim,
|
||||
// previous strides and steps
|
||||
template <typename... Ts, typename... Ls, typename... Steps, typename FlattenDescType>
|
||||
__host__ __device__ constexpr auto
|
||||
CalculateLocalPartitionDescriptor(const Tuple<Ts...>& shape,
|
||||
const Tuple<Ls...>& thread_lengths,
|
||||
const Tuple<Steps...>& steps,
|
||||
const FlattenDescType& flatten_desc)
|
||||
{
|
||||
|
||||
static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
|
||||
const auto unrolled_thread_lengths = UnrollNestedTuple(thread_lengths);
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
constexpr auto dims = decltype(unrolled_thread_lengths)::Size();
|
||||
|
||||
using UnrolledStepsType = decltype(UnrollNestedTuple(steps));
|
||||
|
||||
using I1 = Number<1>;
|
||||
|
||||
const auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto num_i = Number<i>{};
|
||||
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
|
||||
{
|
||||
// By default raked partition
|
||||
const auto partition_stride = unrolled_thread_lengths.At(num_i);
|
||||
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
|
||||
make_tuple(partition_stride));
|
||||
}
|
||||
else if constexpr(!is_same_v<tuple_element_t<i.value, UnrolledStepsType>, index_t>)
|
||||
{
|
||||
// Compiletime partition
|
||||
if constexpr(is_same_v<tuple_element_t<i.value, UnrolledStepsType>, I1>)
|
||||
{
|
||||
// raked
|
||||
const auto partition_stride = unrolled_thread_lengths.At(num_i);
|
||||
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
|
||||
make_tuple(partition_stride));
|
||||
}
|
||||
else
|
||||
{
|
||||
// packed
|
||||
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
|
||||
make_tuple(I1{}));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Runtime partition
|
||||
if(steps.At(num_i) == 1)
|
||||
{
|
||||
// raked
|
||||
const auto partition_stride = unrolled_thread_lengths.At(num_i);
|
||||
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
|
||||
make_tuple(partition_stride));
|
||||
}
|
||||
else
|
||||
{
|
||||
// packed
|
||||
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
|
||||
make_tuple(I1{}));
|
||||
}
|
||||
}
|
||||
},
|
||||
Number<dims>{});
|
||||
|
||||
const auto lower_dims =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
|
||||
const auto upper_dims =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
|
||||
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
|
||||
}
|
||||
|
||||
template <typename... Ls, typename... Steps>
|
||||
__host__ __device__ constexpr auto CalculateLayoutOffsetIdxImpl(const Tuple<Ls...>& thread_lengths,
|
||||
const Tuple<Steps...>& steps,
|
||||
index_t& thread_id)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto num_i = Number<i>{};
|
||||
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ls...>>>::value)
|
||||
{
|
||||
// if tuple then recurrence
|
||||
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
|
||||
{
|
||||
return CalculateLayoutOffsetIdxImpl(
|
||||
thread_lengths.At(num_i), Tuple<>{}, thread_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
return CalculateLayoutOffsetIdxImpl(
|
||||
thread_lengths.At(num_i), steps.At(num_i), thread_id);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Update thread_id after each dim
|
||||
const auto dim_thread_id = thread_id % thread_lengths.At(num_i);
|
||||
thread_id /= thread_lengths.At(num_i);
|
||||
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
|
||||
{
|
||||
return dim_thread_id;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Apply step
|
||||
return steps.At(num_i) * dim_thread_id;
|
||||
}
|
||||
}
|
||||
},
|
||||
Number<Tuple<Ls...>::Size()>{});
|
||||
}
|
||||
|
||||
// Convert integer thread_idx to tuple index with steps applied
|
||||
template <typename... Ls, typename... Steps>
|
||||
__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple<Ls...>& thread_lengths,
|
||||
const Tuple<Steps...>& steps,
|
||||
const index_t thread_id)
|
||||
{
|
||||
// Create tmp thread_id copy for CalculateLayoutOffsetIdxImpl updates
|
||||
index_t thread_id_copy = thread_id;
|
||||
return CalculateLayoutOffsetIdxImpl(thread_lengths, steps, thread_id_copy);
|
||||
}
|
||||
|
||||
// Apply steps to index represented as tuple
|
||||
template <typename... Steps, typename... Idxs>
|
||||
__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple<Steps...>& steps,
|
||||
const Tuple<Idxs...>& block_idxs)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto num_i = Number<i>{};
|
||||
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Idxs...>>>::value)
|
||||
{
|
||||
// if tuple then recurrence
|
||||
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
|
||||
{
|
||||
return CalculateLayoutOffsetIdx(Tuple<>{}, block_idxs.At(num_i));
|
||||
}
|
||||
else
|
||||
{
|
||||
return CalculateLayoutOffsetIdx(steps.At(num_i), block_idxs.At(num_i));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
|
||||
{
|
||||
return block_idxs.At(num_i);
|
||||
}
|
||||
else
|
||||
{
|
||||
// apply step
|
||||
return steps.At(num_i) * block_idxs.At(num_i);
|
||||
}
|
||||
}
|
||||
},
|
||||
Number<Tuple<Idxs...>::Size()>{});
|
||||
}
|
||||
|
||||
// User passes only shape per block to the make_local_tile function. This function calculates
|
||||
// block layout based on the shape.
|
||||
template <typename... Ts, typename... BlockDims>
|
||||
__host__ __device__ constexpr auto CalculateBlockLengths(const Tuple<Ts...>& shape,
|
||||
const Tuple<BlockDims...>& tile_shape)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto num_i = Number<i>{};
|
||||
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
|
||||
{
|
||||
// if tuple then recurrence
|
||||
return CalculateBlockLengths(shape.At(num_i), tile_shape.At(num_i));
|
||||
}
|
||||
else
|
||||
{
|
||||
return shape.At(num_i) / tile_shape.At(num_i);
|
||||
}
|
||||
},
|
||||
Number<Tuple<Ts...>::Size()>{});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/**
|
||||
* \brief Create local partition for thread.
|
||||
*
|
||||
* \param tensor Tensor for partition.
|
||||
* \param thread_lengths Layout of threads.
|
||||
* \param thread_id Thread index represented as integer.
|
||||
* \param steps Thread step (default=1, raked partition)
|
||||
* \return Partition tensor.
|
||||
*/
|
||||
template <typename TensorType, typename ThreadLengthsTuple, typename StepsTuple = Tuple<>>
|
||||
__host__ __device__ constexpr auto make_local_partition(const TensorType& tensor,
|
||||
const ThreadLengthsTuple& thread_lengths,
|
||||
const index_t thread_id,
|
||||
const StepsTuple steps = StepsTuple{})
|
||||
{
|
||||
// Create shape, strides and layout for new partition tensor
|
||||
const auto partition_shape = CalculateLocalPartitionShape(shape(tensor), thread_lengths);
|
||||
// Create new descriptor and layout
|
||||
const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor();
|
||||
auto partition_desc =
|
||||
CalculateLocalPartitionDescriptor(shape(tensor), thread_lengths, steps, flatten_desc);
|
||||
const auto partition_layout = Layout<decltype(partition_shape), decltype(partition_desc)>(
|
||||
partition_shape, partition_desc);
|
||||
// Calculate offset for new partition tensor
|
||||
const auto offset_idx = CalculateLayoutOffsetIdx(thread_lengths, steps, thread_id);
|
||||
const auto partition_offset = layout(tensor)(offset_idx);
|
||||
return make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer() + partition_offset,
|
||||
partition_layout);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Create local tile for thread block.
|
||||
*
|
||||
* \param tensor Tensor for partition.
|
||||
* \param tile_shape Shapes of requested tile.
|
||||
* \param block_idx Block index represented as tuple.
|
||||
* \param steps Block step (default=1, raked partition)
|
||||
* \return Tile tensor.
|
||||
*/
|
||||
template <typename TensorType,
|
||||
typename BlockShapeTuple,
|
||||
typename BlockIdxTuple,
|
||||
typename StepsTuple = Tuple<>>
|
||||
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
|
||||
const BlockShapeTuple& tile_shape,
|
||||
const BlockIdxTuple& block_idx,
|
||||
const StepsTuple steps = StepsTuple{})
|
||||
{
|
||||
// Create block lengths, strides and layout for new tile tensor
|
||||
const auto block_lengths = CalculateBlockLengths(shape(tensor), tile_shape);
|
||||
// Create new descriptor and layout
|
||||
const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor();
|
||||
auto tile_desc =
|
||||
CalculateLocalPartitionDescriptor(tile_shape, block_lengths, steps, flatten_desc);
|
||||
const auto tile_layout = Layout<remove_reference_t<decltype(tile_shape)>, decltype(tile_desc)>(
|
||||
tile_shape, tile_desc);
|
||||
// Calculate offset for new partition tensor
|
||||
const auto offset_idx = CalculateLayoutOffsetIdx(steps, block_idx);
|
||||
const auto tile_offset = layout(tensor)(offset_idx);
|
||||
return make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer() + tile_offset,
|
||||
tile_layout);
|
||||
}
|
||||
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
@@ -27,12 +27,12 @@ using MemoryTypeEnum = AddressSpaceEnum;
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond
|
||||
// forward declarations
|
||||
template <typename Shape, typename Strides>
|
||||
template <typename Shape, typename UnnestedDescriptorType>
|
||||
struct Layout;
|
||||
template <MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides,
|
||||
typename UnnestedDescriptorType,
|
||||
index_t NumVectors, // params for Register memory
|
||||
index_t ScalarPerVector // param for Register memory
|
||||
>
|
||||
@@ -98,11 +98,19 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
* \param layout Tensor layout.
|
||||
* \return Constructed tensor.
|
||||
*/
|
||||
template <MemoryTypeEnum MemoryType, typename ElementType, typename Shape, typename Strides>
|
||||
constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& layout)
|
||||
template <MemoryTypeEnum MemoryType,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename UnnestedDescriptorType>
|
||||
constexpr auto make_tensor(ElementType* pointer,
|
||||
const Layout<Shape, UnnestedDescriptorType>& layout)
|
||||
{
|
||||
return Tensor<MemoryType, ElementType, Shape, Strides, 0 /*NumVectors*/, 0 /*ScalarPerVector*/>(
|
||||
pointer, layout);
|
||||
return Tensor<MemoryType,
|
||||
ElementType,
|
||||
Shape,
|
||||
UnnestedDescriptorType,
|
||||
0 /*NumVectors*/,
|
||||
0 /*ScalarPerVector*/>(pointer, layout);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -112,19 +120,21 @@ constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& l
|
||||
* \tparam NumVectors Number of vectors.
|
||||
* \tparam ScalarPerVector Scalars per vector.
|
||||
* \tparam ElementType Memory data type.
|
||||
* \param layout Tensor layout.
|
||||
* \return Constructed tensor.
|
||||
*/
|
||||
template <MemoryTypeEnum MemoryType,
|
||||
index_t NumVectors,
|
||||
index_t ScalarPerVector,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides>
|
||||
constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
|
||||
typename ElementType>
|
||||
constexpr auto make_register_tensor()
|
||||
{
|
||||
static_assert(!IsNestedTuple(Shape{}), "Register tensor with nested layout is not supported");
|
||||
return Tensor<MemoryType, ElementType, Shape, Strides, NumVectors, ScalarPerVector>(layout);
|
||||
const auto layout = make_layout(make_tuple(Number<NumVectors>{}), make_tuple(Number<1>{}));
|
||||
return Tensor<MemoryType,
|
||||
ElementType,
|
||||
Tuple<Number<NumVectors>>,
|
||||
std::remove_const_t<remove_reference_t<decltype(layout.GetUnnestedDescriptor())>>,
|
||||
NumVectors,
|
||||
ScalarPerVector>(layout);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -136,12 +146,15 @@ constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
|
||||
template <MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides,
|
||||
typename UnnestedDescriptorType,
|
||||
index_t NumVectors,
|
||||
index_t ScalarPerVector>
|
||||
__host__ __device__ constexpr const auto&
|
||||
layout(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
|
||||
tensor)
|
||||
__host__ __device__ constexpr const auto& layout(const Tensor<BufferAddressSpace,
|
||||
ElementType,
|
||||
Shape,
|
||||
UnnestedDescriptorType,
|
||||
NumVectors,
|
||||
ScalarPerVector>& tensor)
|
||||
{
|
||||
return tensor.GetLayout();
|
||||
}
|
||||
@@ -157,12 +170,15 @@ template <index_t... Idxs,
|
||||
MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides,
|
||||
typename UnnestedDescriptorType,
|
||||
index_t NumVectors,
|
||||
index_t ScalarPerVector>
|
||||
__host__ __device__ constexpr index_t
|
||||
size(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
|
||||
tensor)
|
||||
__host__ __device__ constexpr auto size(const Tensor<BufferAddressSpace,
|
||||
ElementType,
|
||||
Shape,
|
||||
UnnestedDescriptorType,
|
||||
NumVectors,
|
||||
ScalarPerVector>& tensor)
|
||||
{
|
||||
return size<Idxs...>(tensor.GetLayout());
|
||||
}
|
||||
@@ -178,12 +194,15 @@ template <index_t... Idxs,
|
||||
MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides,
|
||||
typename UnnestedDescriptorType,
|
||||
index_t NumVectors,
|
||||
index_t ScalarPerVector>
|
||||
__host__ __device__ constexpr index_t
|
||||
rank(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
|
||||
tensor)
|
||||
__host__ __device__ constexpr auto rank(const Tensor<BufferAddressSpace,
|
||||
ElementType,
|
||||
Shape,
|
||||
UnnestedDescriptorType,
|
||||
NumVectors,
|
||||
ScalarPerVector>& tensor)
|
||||
{
|
||||
return rank<Idxs...>(tensor.GetLayout());
|
||||
}
|
||||
@@ -199,35 +218,19 @@ template <index_t... Idxs,
|
||||
MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides,
|
||||
typename UnnestedDescriptorType,
|
||||
index_t NumVectors,
|
||||
index_t ScalarPerVector>
|
||||
__host__ __device__ constexpr index_t
|
||||
depth(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
|
||||
tensor)
|
||||
__host__ __device__ constexpr auto depth(const Tensor<BufferAddressSpace,
|
||||
ElementType,
|
||||
Shape,
|
||||
UnnestedDescriptorType,
|
||||
NumVectors,
|
||||
ScalarPerVector>& tensor)
|
||||
{
|
||||
return depth<Idxs...>(tensor.GetLayout());
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get Tensor strides.
|
||||
*
|
||||
* \param tensor Tensor to get strides from.
|
||||
* \return Requsted strides.
|
||||
*/
|
||||
template <MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides,
|
||||
index_t NumVectors,
|
||||
index_t ScalarPerVector>
|
||||
__host__ __device__ constexpr const auto&
|
||||
stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
|
||||
tensor)
|
||||
{
|
||||
return stride(tensor.GetLayout());
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get Tensor shape.
|
||||
*
|
||||
@@ -237,12 +240,15 @@ stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors,
|
||||
template <MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides,
|
||||
typename UnnestedDescriptorType,
|
||||
index_t NumVectors,
|
||||
index_t ScalarPerVector>
|
||||
__host__ __device__ constexpr const auto&
|
||||
shape(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
|
||||
tensor)
|
||||
__host__ __device__ constexpr const auto& shape(const Tensor<BufferAddressSpace,
|
||||
ElementType,
|
||||
Shape,
|
||||
UnnestedDescriptorType,
|
||||
NumVectors,
|
||||
ScalarPerVector>& tensor)
|
||||
{
|
||||
return shape(tensor.GetLayout());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user