mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
Introduce wrapper for layout (#1054)
* Introduce wrapper for layout * Extend functionality * Fix for getLength * Comment fixes * Add comments and remove not needed getters
This commit is contained in:
2
example/64_tensor_transforms/CMakeLists.txt
Normal file
2
example/64_tensor_transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
# Example target built from tensor_transform.cpp (add_example_executable is a
# helper defined elsewhere in the project - its exact behavior is not shown here).
add_example_executable(example_tensor_transform tensor_transform.cpp)
|
||||
# Example target built from tensor_transform_using_wrapper.cpp, the variant of
# the example that goes through the layout wrapper header.
add_example_executable(example_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
|
||||
150
example/64_tensor_transforms/tensor_transform.cpp
Normal file
150
example/64_tensor_transforms/tensor_transform.cpp
Normal file
@@ -0,0 +1,150 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
|
||||
static constexpr auto I0 = ck::Number<0>{};
|
||||
static constexpr auto I1 = ck::Number<1>{};
|
||||
static constexpr auto I2 = ck::Number<2>{};
|
||||
|
||||
using DataType = int;
|
||||
|
||||
template <typename Desc>
|
||||
void Print1d(const Desc& desc)
|
||||
{
|
||||
std::cout << "Print1d" << std::endl;
|
||||
for(ck::index_t w = 0; w < desc.GetLength(I0); w++)
|
||||
{
|
||||
std::cout << desc.CalculateOffset(ck::make_tuple(w)) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename Desc>
|
||||
void Print2d(const Desc& desc)
|
||||
{
|
||||
std::cout << "Print2d" << std::endl;
|
||||
for(ck::index_t h = 0; h < desc.GetLength(I0); h++)
|
||||
{
|
||||
for(ck::index_t w = 0; w < desc.GetLength(I1); w++)
|
||||
{
|
||||
std::cout << desc.CalculateOffset(ck::make_tuple(h, w)) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Desc>
|
||||
void Print3dCustom(const Desc& desc)
|
||||
{
|
||||
std::cout << "Print3dCustom" << std::endl;
|
||||
for(ck::index_t d = 0; d < desc.GetLength(I0); d++)
|
||||
{
|
||||
for(ck::index_t h = 0; h < desc.GetLength(I1); h++)
|
||||
{
|
||||
for(ck::index_t w = 0; w < desc.GetLength(I2); w++)
|
||||
{
|
||||
std::cout << desc.CalculateOffset(ck::make_tuple(d, h, w)) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
// Tensor descriptor traverse in row-major (need to reverse dims)
|
||||
std::cout << "Note: Tensor descriptor traverse in row-major" << std::endl;
|
||||
// Basic descriptor 0, 1, 2, ... 30, 31
|
||||
// (dims:4,8 strides:1,4)
|
||||
const auto desc_4x8_s1x4 =
|
||||
ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}),
|
||||
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}));
|
||||
std::cout << "dims:4,8 strides:1,4" << std::endl;
|
||||
Print2d(desc_4x8_s1x4);
|
||||
|
||||
using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
|
||||
constexpr ck::index_t offset_1x1 = desc_4x8_s1x4.CalculateOffset(Cord1x1Type{});
|
||||
std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl;
|
||||
|
||||
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
|
||||
// dims:4,(2,4) strides:2,(1,8)
|
||||
const auto desc_4x2x4_s2x1x8 =
|
||||
ck::make_naive_tensor_descriptor(ck::make_tuple(4, 2, 4), ck::make_tuple(2, 1, 8));
|
||||
// Transform to 2d (column-major, need to to reverse dims)
|
||||
const auto desc_4x2x4_s2x1x8_merged = ck::transform_tensor_descriptor(
|
||||
desc_4x2x4_s2x1x8,
|
||||
ck::make_tuple(ck::make_pass_through_transform(4),
|
||||
ck::make_merge_transform(ck::make_tuple(4, 2))),
|
||||
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<2, 1>{}),
|
||||
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
|
||||
|
||||
std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
|
||||
Print2d(desc_4x2x4_s2x1x8_merged);
|
||||
|
||||
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
|
||||
// dims:(2,2),(2,4) strides:((1,4),(2,8)
|
||||
const auto desc_2x2x2x4_s1x4x2x8 =
|
||||
ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8));
|
||||
// Transform to 2d
|
||||
const auto desc_2x2x2x4_s1x4x2x8_double_merged_2d = ck::transform_tensor_descriptor(
|
||||
desc_2x2x2x4_s1x4x2x8,
|
||||
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
|
||||
ck::make_merge_transform(ck::make_tuple(4, 2))),
|
||||
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}),
|
||||
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
|
||||
// Transform to 3d
|
||||
const auto desc_2x2x2x4_s1x4x2x8_double_merged_3d = ck::transform_tensor_descriptor(
|
||||
desc_2x2x2x4_s1x4x2x8,
|
||||
ck::make_tuple(ck::make_pass_through_transform(2),
|
||||
ck::make_pass_through_transform(2),
|
||||
ck::make_merge_transform(ck::make_tuple(4, 2))),
|
||||
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<3, 2>{}),
|
||||
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
|
||||
|
||||
std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
|
||||
Print2d(desc_2x2x2x4_s1x4x2x8_double_merged_2d);
|
||||
Print3dCustom(desc_2x2x2x4_s1x4x2x8_double_merged_3d);
|
||||
|
||||
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
|
||||
// dims:((2,2),2),4 strides:((1,4),2),8
|
||||
// Transform to 2d
|
||||
const auto desc_2x2x2x4_s1x4x2x8_nested =
|
||||
ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8));
|
||||
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_3d = ck::transform_tensor_descriptor(
|
||||
desc_2x2x2x4_s1x4x2x8_nested,
|
||||
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
|
||||
ck::make_pass_through_transform(2),
|
||||
ck::make_pass_through_transform(4)),
|
||||
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}),
|
||||
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
|
||||
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_1d = ck::transform_tensor_descriptor(
|
||||
desc_2x2x2x4_s1x4x2x8_nested,
|
||||
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(4, 2, 2, 2))),
|
||||
ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}),
|
||||
ck::make_tuple(ck::Sequence<0>{}));
|
||||
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_2d = ck::transform_tensor_descriptor(
|
||||
desc_2x2x2x4_s1x4x2x8_nested_merged_3d,
|
||||
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 4)),
|
||||
ck::make_pass_through_transform(4)),
|
||||
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}),
|
||||
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
|
||||
|
||||
std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl;
|
||||
Print1d(desc_2x2x2x4_s1x4x2x8_nested_merged_1d);
|
||||
Print2d(desc_2x2x2x4_s1x4x2x8_nested_merged_2d);
|
||||
Print3dCustom(desc_2x2x2x4_s1x4x2x8_nested_merged_3d);
|
||||
|
||||
return 0;
|
||||
}
|
||||
119
example/64_tensor_transforms/tensor_transform_using_wrapper.cpp
Normal file
119
example/64_tensor_transforms/tensor_transform_using_wrapper.cpp
Normal file
@@ -0,0 +1,119 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
#include "tensor_transform_wrapper.hpp"
|
||||
|
||||
// Element type alias; nothing in the code visible in this example refers to it.
using DataType = int;
|
||||
|
||||
template <typename Layout>
|
||||
void Print1d(const Layout& layout)
|
||||
{
|
||||
std::cout << "Print1d" << std::endl;
|
||||
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size(layout); w++)
|
||||
{
|
||||
std::cout << layout(ck::make_tuple(w)) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename Layout>
|
||||
void Print2d(const Layout& layout)
|
||||
{
|
||||
std::cout << "Print2d" << std::endl;
|
||||
for(ck::index_t h = 0; h < ck::tensor_transform_wrapper::size<0>(layout); h++)
|
||||
{
|
||||
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++)
|
||||
{
|
||||
std::cout << layout(ck::make_tuple(h, w)) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Print in (x,y),z pattern
|
||||
template <typename Layout>
|
||||
void Print3dCustom(const Layout& layout)
|
||||
{
|
||||
std::cout << "Print3dCustom" << std::endl;
|
||||
for(ck::index_t d = 0;
|
||||
d < ck::tensor_transform_wrapper::size<0>(ck::tensor_transform_wrapper::get<0>(layout));
|
||||
d++)
|
||||
{
|
||||
for(ck::index_t h = 0;
|
||||
h < ck::tensor_transform_wrapper::size<1>(ck::tensor_transform_wrapper::get<0>(layout));
|
||||
h++)
|
||||
{
|
||||
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++)
|
||||
{
|
||||
std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
// Layout traverse in row-major
|
||||
std::cout << "Note: Layout traverse in column-major" << std::endl;
|
||||
// Basic descriptor 0, 1, 2, ... 30, 31 (compile-time descriptor)
|
||||
// (dims:4,8 strides:1,4)
|
||||
const auto shape_4x8 = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{});
|
||||
const auto layout_4x8_s1x4 = ck::tensor_transform_wrapper::make_layout(shape_4x8);
|
||||
std::cout << "dims:4,8 strides:1,4" << std::endl;
|
||||
Print2d(layout_4x8_s1x4);
|
||||
using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
|
||||
constexpr ck::index_t offset_1x1 = layout_4x8_s1x4.template operator()<Cord1x1Type>();
|
||||
std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl;
|
||||
|
||||
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor)
|
||||
// dims:4,(2,4) strides:2,(1,8)
|
||||
const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4));
|
||||
const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8));
|
||||
const auto layout_4x2x4_s2x1x8 =
|
||||
ck::tensor_transform_wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
|
||||
|
||||
std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
|
||||
Print2d(layout_4x2x4_s2x1x8);
|
||||
|
||||
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
|
||||
// dims:(2,2),(2,4) strides:((1,4),(2,8)
|
||||
const auto shape_2x2x2x4 = ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}),
|
||||
ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}));
|
||||
const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}),
|
||||
ck::make_tuple(ck::Number<2>{}, ck::Number<8>{}));
|
||||
static const auto layout_2x2x2x4_s1x4x2x8 =
|
||||
ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8);
|
||||
|
||||
std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
|
||||
Print2d(layout_2x2x2x4_s1x4x2x8);
|
||||
Print3dCustom(layout_2x2x2x4_s1x4x2x8);
|
||||
|
||||
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
|
||||
// dims:((2,2),2),4 strides:((1,4),2),8
|
||||
// Transform to 2d
|
||||
const auto shape_2x2x2x4_nested = ck::make_tuple(
|
||||
ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<2>{}),
|
||||
ck::Number<4>{});
|
||||
const auto strides_s1x4x2x8_nested = ck::make_tuple(
|
||||
ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}),
|
||||
ck::Number<8>{});
|
||||
static const auto layout_2x2x2x4_s1x4x2x8_nested =
|
||||
ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested);
|
||||
|
||||
std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl;
|
||||
Print1d(layout_2x2x2x4_s1x4x2x8_nested);
|
||||
Print2d(layout_2x2x2x4_s1x4x2x8_nested);
|
||||
Print3dCustom(layout_2x2x2x4_s1x4x2x8_nested);
|
||||
|
||||
return 0;
|
||||
}
|
||||
425
example/64_tensor_transforms/tensor_transform_wrapper.hpp
Normal file
425
example/64_tensor_transforms/tensor_transform_wrapper.hpp
Normal file
@@ -0,0 +1,425 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/tuple_helper.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/utility/sequence_helper.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_transform_wrapper {
|
||||
|
||||
/**
|
||||
* \brief Layout wrapper
|
||||
*
|
||||
* \details
|
||||
* Layout wrapper that performs the tensor descriptor logic.
|
||||
*
|
||||
* \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
|
||||
* (dynamic layout). It is possible to pass nested shapes
|
||||
* (e.g. ((4, 2), 2)), nested dimensions are merged.
|
||||
* \tparam Strides Tuple of Number<> (for compile-time layout) or index_t
|
||||
* (dynamic layout). Stride tuple should be nested if shape tuple is
|
||||
* nested.
|
||||
*/
|
||||
template <typename Shape, typename Strides = Tuple<>>
|
||||
struct Layout
|
||||
{
|
||||
private:
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
// Generate packed (column-major) strides if not passed
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr static auto
|
||||
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& tuple)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i.value == 0)
|
||||
{
|
||||
return I1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
|
||||
tuple);
|
||||
}
|
||||
},
|
||||
Number<Tuple<Ts...>::Size()>{});
|
||||
}
|
||||
|
||||
// Generate LowerDims in Compile-time for MergeTrasform using passed Type
|
||||
// If element of Tuple<Ts...> is also tuple, then merge (generate sequence for merge)
|
||||
// If tuple is element, then pass through (sequence with one element)
|
||||
template <typename Idx, typename... Ts>
|
||||
__host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>&)
|
||||
{
|
||||
if constexpr(Idx::value == 0)
|
||||
{
|
||||
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
|
||||
{
|
||||
// Return Sequence for the first tuple
|
||||
constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
|
||||
tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
|
||||
using LowerDimsSequence =
|
||||
typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;
|
||||
return LowerDimsSequence::Reverse();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return first element
|
||||
return Sequence<0>{};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Get previous element using recurence (in compile-time)
|
||||
using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
|
||||
const auto next_seq_val = PreviousSeqT::At(I0) + 1;
|
||||
if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
|
||||
{
|
||||
constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
|
||||
tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
|
||||
using LowerDimsSequence =
|
||||
typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
|
||||
type;
|
||||
return LowerDimsSequence::Reverse();
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<next_seq_val>{};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Iterate over nested tuples in shape
|
||||
// Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
|
||||
// Example idx: (1, 1), 1, 1
|
||||
// Example shape: (2, (2, 2)), 2, (2, 2)
|
||||
// Unrolled shape: 2, (2, 2), 2, (2, 2)
|
||||
template <typename... ShapeDims, typename... IdxDims>
|
||||
__host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple<ShapeDims...>& shape,
|
||||
const Tuple<IdxDims...>& idx)
|
||||
{
|
||||
if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
|
||||
{
|
||||
// Index unrolled to flatten, return shape
|
||||
return shape;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Iterate over shape tuple elements:
|
||||
// 1. If corresponding idx element is tuple then return (will be unrolled)
|
||||
// 2. If no, pack in tuple. It will be restored during unroll.
|
||||
auto unrolled_shape_via_idx = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(is_detected<is_tuple,
|
||||
tuple_element_t<i, Tuple<IdxDims...>>>::value)
|
||||
{
|
||||
return shape.At(i);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(shape.At(i));
|
||||
}
|
||||
},
|
||||
Number<Tuple<IdxDims...>::Size()>{});
|
||||
|
||||
// Unroll and process next step
|
||||
return UnrollShapeViaIdx(UnrollNestedTuple<0, 1>(unrolled_shape_via_idx),
|
||||
UnrollNestedTuple<0, 1>(idx));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... ShapeDims, typename DescriptorToMerge>
|
||||
__host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
|
||||
DescriptorToMerge& desc)
|
||||
{
|
||||
// Reverse each element in tuple
|
||||
using ReversedUnrolledShape = decltype(TupleReverse(UnrollNestedTuple(shape)));
|
||||
const auto merge_elems = ReversedUnrolledShape{};
|
||||
|
||||
// Generate reverted indexes (column major traverse)
|
||||
using MergeElemsSequence =
|
||||
typename arithmetic_sequence_gen<0, ReversedUnrolledShape::Size(), 1>::type;
|
||||
const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
|
||||
const auto upper_dims = make_tuple(Sequence<0>{});
|
||||
// Merge to 1d
|
||||
return transform_tensor_descriptor(
|
||||
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
|
||||
}
|
||||
|
||||
// Merge nested shape dims
|
||||
// Input desc shape: 2, 2, 2, 2, 2, 2
|
||||
// Example idx: 1, 1, 1, 1
|
||||
// Example shape: 2, (2, 2), 2, (2, 2)
|
||||
// Merged shape: 2, 4, 2, 4
|
||||
template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
|
||||
__host__ __device__ constexpr static auto
|
||||
MakeMerges(const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
|
||||
{
|
||||
const auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
// Compare Idx with shape
|
||||
if constexpr(is_detected<is_tuple,
|
||||
tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
|
||||
!is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value)
|
||||
{
|
||||
// If shape element is tuple and idx element is Number, then merge
|
||||
// Unroll and reverse tuple to traverse column-major
|
||||
const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
|
||||
return make_merge_transform(merge_elems);
|
||||
}
|
||||
else
|
||||
{
|
||||
// If shape element is integer and idx element is tuple, passed idx is wrong
|
||||
static_assert(
|
||||
!(!is_detected<is_tuple, tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
|
||||
is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value),
|
||||
"Wrong Idx for layout()");
|
||||
// If shape element has the same type as idx element, then pass through
|
||||
return make_pass_through_transform(shape.At(i));
|
||||
}
|
||||
},
|
||||
Number<Tuple<ShapeDims...>::Size()>{});
|
||||
|
||||
const auto lower_dims =
|
||||
generate_tuple([&](auto i) { return GenerateLowerDim<Number<i>>(shape); },
|
||||
Number<Tuple<ShapeDims...>::Size()>{});
|
||||
const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; },
|
||||
Number<Tuple<ShapeDims...>::Size()>{});
|
||||
|
||||
return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
|
||||
}
|
||||
|
||||
template <typename... ShapeDims, typename... IdxDims>
|
||||
__host__ __device__ constexpr auto TransformDesc(const Tuple<ShapeDims...>& shape,
|
||||
const Tuple<IdxDims...>& idx) const
|
||||
{
|
||||
if constexpr(Tuple<IdxDims...>::Size() == I1)
|
||||
{
|
||||
// 1d idx path
|
||||
return MakeMerge1d(shape, descriptor_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Merge nested shape dims
|
||||
// Example idx: (1, 1), 1, 1
|
||||
// Example shape: (2, (2, 2)), 2, (2, 2)
|
||||
// Merged shape: (2, 4), 2, 4
|
||||
static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
|
||||
"Idx rank and Shape rank must be the same (except 1d).");
|
||||
// Unroll while IdxDims is nested
|
||||
const auto unrolled_shape_via_idx = UnrollShapeViaIdx(shape, idx);
|
||||
// Transform correct form of shape
|
||||
return MakeMerges(unrolled_shape_via_idx, UnrollNestedTuple(idx), descriptor_);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LayoutShape, typename LayoutStrides>
|
||||
__host__ __device__ static auto MakeNaiveDescriptor(const LayoutShape& shape,
|
||||
const LayoutStrides& strides)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
|
||||
if constexpr(ck::is_same_v<LayoutStrides, Tuple<>>)
|
||||
{
|
||||
// If shape is packed
|
||||
const auto column_major_packed_strides =
|
||||
GenerateColumnMajorPackedStrides(unrolled_shape);
|
||||
return make_naive_tensor_descriptor(unrolled_shape, column_major_packed_strides);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto unrolled_strides = UnrollNestedTuple(strides);
|
||||
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
|
||||
"Size of strides and shape are not consistent.");
|
||||
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
using NaiveDescriptorType = remove_cvref_t<decltype(MakeNaiveDescriptor(Shape{}, Strides{}))>;
|
||||
|
||||
/**
|
||||
* \brief Layout constructor.
|
||||
*
|
||||
* \param shape Shape for layout.
|
||||
* \param strides Strides for layout (optional if tensor is packed).
|
||||
* \return Layout object.
|
||||
*/
|
||||
__host__ __device__ Layout() = delete;
|
||||
__host__ __device__ Layout(const Shape& shape, const Strides& strides) : descriptor_{}
|
||||
{
|
||||
// Construct if runtime mode
|
||||
if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
|
||||
{
|
||||
// Keep only shape, strides are not need for transforms
|
||||
shape_ = shape;
|
||||
descriptor_ = MakeNaiveDescriptor(shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ Layout(const Shape& shape) : descriptor_{}
|
||||
{
|
||||
if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
|
||||
{
|
||||
shape_ = shape;
|
||||
descriptor_ = MakeNaiveDescriptor(shape, Strides{});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Returns real offset to element in runtime.
|
||||
*
|
||||
* \tparam Idxs Tuple of indexes.
|
||||
* \return Calculated offset.
|
||||
*/
|
||||
template <typename Idxs>
|
||||
__host__ __device__ constexpr index_t operator()() const
|
||||
{
|
||||
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}));
|
||||
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
|
||||
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Returns real offset to element in compile time.
|
||||
*
|
||||
* \param Idx Tuple of indexes.
|
||||
* \return Calculated offset.
|
||||
*/
|
||||
template <typename... Ts>
|
||||
__host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const
|
||||
{
|
||||
// Static to construct transformed_desc only once
|
||||
static const auto transformed_desc = TransformDesc(shape_, Idx);
|
||||
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Length getter (product if tuple).
|
||||
*
|
||||
* \tparam IDim Tuple of indexes or index.
|
||||
* \return Calculated size.
|
||||
*/
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr index_t GetLength() const
|
||||
{
|
||||
const auto elem = shape_.At(Number<IDim>{});
|
||||
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
|
||||
{
|
||||
const auto unrolled_element = UnrollNestedTuple(elem);
|
||||
return TupleReduce<I0.value, unrolled_element.Size()>(
|
||||
[](auto x, auto y) { return x * y; }, unrolled_element);
|
||||
}
|
||||
else
|
||||
{
|
||||
return elem;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Layout size getter (product of shape).
|
||||
*
|
||||
* \return Calculated size.
|
||||
*/
|
||||
__host__ __device__ constexpr index_t GetLength() const
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape_);
|
||||
return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
|
||||
unrolled_shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Dimension getter.
|
||||
*
|
||||
* \tparam IDim Dimension idx.
|
||||
* \return Calculated size.
|
||||
*/
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr auto Get() const
|
||||
{
|
||||
const auto elem = shape_.At(Number<IDim>{});
|
||||
return elem;
|
||||
}
|
||||
|
||||
private:
|
||||
NaiveDescriptorType descriptor_;
|
||||
Shape shape_;
|
||||
};
|
||||
|
||||
// Layout helpers
|
||||
// Length getter (product if tuple)
|
||||
template <index_t idx, typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return layout.template GetLength<idx>();
|
||||
}
|
||||
|
||||
// Get shape size (product of dims if tuple)
|
||||
template <typename... ShapeDims>
|
||||
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
|
||||
{
|
||||
using UnrolledShape = decltype(UnrollNestedTuple(shape));
|
||||
return TupleReduce<0, UnrolledShape::Size()>([](auto x, auto y) { return x * y; },
|
||||
UnrolledShape{});
|
||||
}
|
||||
|
||||
// Get dim size (could be returned from get function)
|
||||
template <typename T>
|
||||
__host__ __device__ T constexpr size(const T& dim)
|
||||
{
|
||||
return dim;
|
||||
}
|
||||
|
||||
// Get layout size (product of shapes)
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return layout.GetLength();
|
||||
}
|
||||
|
||||
// Get shape element size
|
||||
template <index_t idx, typename... ShapeDims>
|
||||
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
|
||||
{
|
||||
return size(shape.At(Number<idx>{}));
|
||||
}
|
||||
|
||||
// Dim getter (tuple if tuple)
|
||||
template <index_t idx, typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return layout.template Get<idx>();
|
||||
}
|
||||
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
|
||||
const Strides& strides)
|
||||
{
|
||||
return Layout<Shape, Strides>(shape, strides);
|
||||
}
|
||||
|
||||
template <typename Shape>
|
||||
__host__ __device__ constexpr Layout<Shape> make_layout(const Shape& shape)
|
||||
{
|
||||
return Layout<Shape>(shape);
|
||||
}
|
||||
|
||||
} // namespace tensor_transform_wrapper
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user