From 8ff845f2c4aa7bbdd728b7639a79e0b6932c6dab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Thu, 30 Nov 2023 12:11:43 +0100 Subject: [PATCH] Introduce wrapper for layout (#1054) * Introduce wrapper for layout * Extend functionality * Fix for getLength * Comment fixes * Add comments and remove not needed getters --- example/64_tensor_transforms/CMakeLists.txt | 2 + .../64_tensor_transforms/tensor_transform.cpp | 150 +++++++ .../tensor_transform_using_wrapper.cpp | 119 +++++ .../tensor_transform_wrapper.hpp | 425 ++++++++++++++++++ include/ck/utility/tuple_helper.hpp | 88 ++++ 5 files changed, 784 insertions(+) create mode 100644 example/64_tensor_transforms/CMakeLists.txt create mode 100644 example/64_tensor_transforms/tensor_transform.cpp create mode 100644 example/64_tensor_transforms/tensor_transform_using_wrapper.cpp create mode 100644 example/64_tensor_transforms/tensor_transform_wrapper.hpp diff --git a/example/64_tensor_transforms/CMakeLists.txt b/example/64_tensor_transforms/CMakeLists.txt new file mode 100644 index 0000000000..9d14a410e3 --- /dev/null +++ b/example/64_tensor_transforms/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_tensor_transform tensor_transform.cpp) +add_example_executable(example_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp) diff --git a/example/64_tensor_transforms/tensor_transform.cpp b/example/64_tensor_transforms/tensor_transform.cpp new file mode 100644 index 0000000000..41ceec1cb5 --- /dev/null +++ b/example/64_tensor_transforms/tensor_transform.cpp @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" + +#include "ck/utility/number.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/sequence.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" + +static constexpr auto I0 = ck::Number<0>{}; +static constexpr auto I1 = ck::Number<1>{}; +static constexpr auto I2 = ck::Number<2>{}; + +using DataType = int; + +template +void Print1d(const Desc& desc) +{ + std::cout << "Print1d" << std::endl; + for(ck::index_t w = 0; w < desc.GetLength(I0); w++) + { + std::cout << desc.CalculateOffset(ck::make_tuple(w)) << " "; + } + std::cout << std::endl; +} + +template +void Print2d(const Desc& desc) +{ + std::cout << "Print2d" << std::endl; + for(ck::index_t h = 0; h < desc.GetLength(I0); h++) + { + for(ck::index_t w = 0; w < desc.GetLength(I1); w++) + { + std::cout << desc.CalculateOffset(ck::make_tuple(h, w)) << " "; + } + std::cout << std::endl; + } +} + +template +void Print3dCustom(const Desc& desc) +{ + std::cout << "Print3dCustom" << std::endl; + for(ck::index_t d = 0; d < desc.GetLength(I0); d++) + { + for(ck::index_t h = 0; h < desc.GetLength(I1); h++) + { + for(ck::index_t w = 0; w < desc.GetLength(I2); w++) + { + std::cout << desc.CalculateOffset(ck::make_tuple(d, h, w)) << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; + } +} + +int main() +{ + // Tensor descriptor traverse in row-major (need to reverse dims) + std::cout << "Note: Tensor descriptor traverse in row-major" << std::endl; + // Basic descriptor 0, 1, 2, ... 
30, 31 + // (dims:4,8 strides:1,4) + const auto desc_4x8_s1x4 = + ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}), + ck::make_tuple(ck::Number<1>{}, ck::Number<4>{})); + std::cout << "dims:4,8 strides:1,4" << std::endl; + Print2d(desc_4x8_s1x4); + + using Cord1x1Type = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t offset_1x1 = desc_4x8_s1x4.CalculateOffset(Cord1x1Type{}); + std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl; + + // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor) + // dims:4,(2,4) strides:2,(1,8) + const auto desc_4x2x4_s2x1x8 = + ck::make_naive_tensor_descriptor(ck::make_tuple(4, 2, 4), ck::make_tuple(2, 1, 8)); + // Transform to 2d (column-major, need to reverse dims) + const auto desc_4x2x4_s2x1x8_merged = ck::transform_tensor_descriptor( + desc_4x2x4_s2x1x8, + ck::make_tuple(ck::make_pass_through_transform(4), + ck::make_merge_transform(ck::make_tuple(4, 2))), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<2, 1>{}), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); + + std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl; + Print2d(desc_4x2x4_s2x1x8_merged); + + // Basic descriptor 0, 1, 8, 9, 16, 17, ... 
30, 31 (compile-time descriptor) + // dims:(2,2),(2,4) strides:((1,4),(2,8) + const auto desc_2x2x2x4_s1x4x2x8 = + ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8)); + // Transform to 2d + const auto desc_2x2x2x4_s1x4x2x8_double_merged_2d = ck::transform_tensor_descriptor( + desc_2x2x2x4_s1x4x2x8, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)), + ck::make_merge_transform(ck::make_tuple(4, 2))), + ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); + // Transform to 3d + const auto desc_2x2x2x4_s1x4x2x8_double_merged_3d = ck::transform_tensor_descriptor( + desc_2x2x2x4_s1x4x2x8, + ck::make_tuple(ck::make_pass_through_transform(2), + ck::make_pass_through_transform(2), + ck::make_merge_transform(ck::make_tuple(4, 2))), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<3, 2>{}), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{})); + + std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl; + Print2d(desc_2x2x2x4_s1x4x2x8_double_merged_2d); + Print3dCustom(desc_2x2x2x4_s1x4x2x8_double_merged_3d); + + // Basic descriptor 0, 1, 8, 9, 16, 17, ... 
30, 31 (compile-time descriptor) + // dims:((2,2),2),4 strides:((1,4),2),8 + // Transform to 2d + const auto desc_2x2x2x4_s1x4x2x8_nested = + ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8)); + const auto desc_2x2x2x4_s1x4x2x8_nested_merged_3d = ck::transform_tensor_descriptor( + desc_2x2x2x4_s1x4x2x8_nested, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)), + ck::make_pass_through_transform(2), + ck::make_pass_through_transform(4)), + ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{})); + const auto desc_2x2x2x4_s1x4x2x8_nested_merged_1d = ck::transform_tensor_descriptor( + desc_2x2x2x4_s1x4x2x8_nested, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(4, 2, 2, 2))), + ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}), + ck::make_tuple(ck::Sequence<0>{})); + const auto desc_2x2x2x4_s1x4x2x8_nested_merged_2d = ck::transform_tensor_descriptor( + desc_2x2x2x4_s1x4x2x8_nested_merged_3d, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 4)), + ck::make_pass_through_transform(4)), + ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); + + std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl; + Print1d(desc_2x2x2x4_s1x4x2x8_nested_merged_1d); + Print2d(desc_2x2x2x4_s1x4x2x8_nested_merged_2d); + Print3dCustom(desc_2x2x2x4_s1x4x2x8_nested_merged_3d); + + return 0; +} diff --git a/example/64_tensor_transforms/tensor_transform_using_wrapper.cpp b/example/64_tensor_transforms/tensor_transform_using_wrapper.cpp new file mode 100644 index 0000000000..df2449e99d --- /dev/null +++ b/example/64_tensor_transforms/tensor_transform_using_wrapper.cpp @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" + +#include "ck/utility/number.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/sequence.hpp" + +#include "tensor_transform_wrapper.hpp" + +using DataType = int; + +template +void Print1d(const Layout& layout) +{ + std::cout << "Print1d" << std::endl; + for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size(layout); w++) + { + std::cout << layout(ck::make_tuple(w)) << " "; + } + std::cout << std::endl; +} + +template +void Print2d(const Layout& layout) +{ + std::cout << "Print2d" << std::endl; + for(ck::index_t h = 0; h < ck::tensor_transform_wrapper::size<0>(layout); h++) + { + for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++) + { + std::cout << layout(ck::make_tuple(h, w)) << " "; + } + std::cout << std::endl; + } +} + +// Print in (x,y),z pattern +template +void Print3dCustom(const Layout& layout) +{ + std::cout << "Print3dCustom" << std::endl; + for(ck::index_t d = 0; + d < ck::tensor_transform_wrapper::size<0>(ck::tensor_transform_wrapper::get<0>(layout)); + d++) + { + for(ck::index_t h = 0; + h < ck::tensor_transform_wrapper::size<1>(ck::tensor_transform_wrapper::get<0>(layout)); + h++) + { + for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++) + { + std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; + } +} + +int main() +{ + // Layout traverse in column-major + std::cout << "Note: Layout traverse in column-major" << std::endl; + // Basic descriptor 0, 1, 2, ... 
30, 31 (compile-time descriptor) + // (dims:4,8 strides:1,4) + const auto shape_4x8 = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}); + const auto layout_4x8_s1x4 = ck::tensor_transform_wrapper::make_layout(shape_4x8); + std::cout << "dims:4,8 strides:1,4" << std::endl; + Print2d(layout_4x8_s1x4); + using Cord1x1Type = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t offset_1x1 = layout_4x8_s1x4.template operator()(); + std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl; + + // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor) + // dims:4,(2,4) strides:2,(1,8) + const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); + const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); + const auto layout_4x2x4_s2x1x8 = + ck::tensor_transform_wrapper::make_layout(shape_4x2x4, strides_s2x1x8); + + std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl; + Print2d(layout_4x2x4_s2x1x8); + + // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor) + // dims:(2,2),(2,4) strides:((1,4),(2,8) + const auto shape_2x2x2x4 = ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), + ck::make_tuple(ck::Number<2>{}, ck::Number<4>{})); + const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), + ck::make_tuple(ck::Number<2>{}, ck::Number<8>{})); + static const auto layout_2x2x2x4_s1x4x2x8 = + ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8); + + std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl; + Print2d(layout_2x2x2x4_s1x4x2x8); + Print3dCustom(layout_2x2x2x4_s1x4x2x8); + + // Basic descriptor 0, 1, 8, 9, 16, 17, ... 
30, 31 (compile-time descriptor) + // dims:((2,2),2),4 strides:((1,4),2),8 + // Transform to 2d + const auto shape_2x2x2x4_nested = ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<2>{}), + ck::Number<4>{}); + const auto strides_s1x4x2x8_nested = ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}), + ck::Number<8>{}); + static const auto layout_2x2x2x4_s1x4x2x8_nested = + ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested); + + std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl; + Print1d(layout_2x2x2x4_s1x4x2x8_nested); + Print2d(layout_2x2x2x4_s1x4x2x8_nested); + Print3dCustom(layout_2x2x2x4_s1x4x2x8_nested); + + return 0; +} diff --git a/example/64_tensor_transforms/tensor_transform_wrapper.hpp b/example/64_tensor_transforms/tensor_transform_wrapper.hpp new file mode 100644 index 0000000000..71cd6091f8 --- /dev/null +++ b/example/64_tensor_transforms/tensor_transform_wrapper.hpp @@ -0,0 +1,425 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" + +#include "ck/utility/number.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/tuple_helper.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/utility/is_detected.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" + +namespace ck { +namespace tensor_transform_wrapper { + +/** + * \brief Layout wrapper + * + * \details + * Layout wrapper that performs the tensor descriptor logic. + * + * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t + * (dynamic layout). It is possible to pass nested shapes + * (e.g. ((4, 2), 2)), nested dimensions are merged. 
+ * \tparam Strides Tuple of Number<> (for compile-time layout) or index_t + * (dynamic layout). Stride tuple should be nested if shape tuple is + * nested. + */ +template > +struct Layout +{ + private: + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + template + using is_tuple = decltype(std::declval().IsTuple()); + + // Generate packed (column-major) strides if not passed + template + __host__ __device__ constexpr static auto + GenerateColumnMajorPackedStrides(const Tuple& tuple) + { + return generate_tuple( + [&](auto i) { + if constexpr(i.value == 0) + { + return I1; + } + else + { + return TupleReduce([](auto x, auto y) { return x * y; }, + tuple); + } + }, + Number::Size()>{}); + } + + // Generate LowerDims in Compile-time for MergeTransform using passed Type + // If element of Tuple is also tuple, then merge (generate sequence for merge) + // If element is not a tuple, then pass through (sequence with one element) + template + __host__ __device__ constexpr static auto GenerateLowerDim(const Tuple&) + { + if constexpr(Idx::value == 0) + { + if constexpr(is_detected>>::value) + { + // Return Sequence for the first tuple + constexpr index_t merge_nelems = decltype(UnrollNestedTuple( + tuple_element_t>{}))::Size(); + using LowerDimsSequence = + typename arithmetic_sequence_gen<0, merge_nelems, 1>::type; + return LowerDimsSequence::Reverse(); + } + else + { + // Return first element + return Sequence<0>{}; + } + } + else + { + // Get previous element using recursion (in compile-time) + using PreviousSeqT = decltype(GenerateLowerDim>(Tuple{})); + const auto next_seq_val = PreviousSeqT::At(I0) + 1; + if constexpr(is_detected>>::value) + { + constexpr index_t merge_nelems = decltype(UnrollNestedTuple( + tuple_element_t>{}))::Size(); + using LowerDimsSequence = + typename arithmetic_sequence_gen:: + type; + return LowerDimsSequence::Reverse(); + } + else + { + return Sequence{}; + } + } + } + + // Iterate over nested tuples in shape + 
// Unroll nested tuples to align Tuple to Tuple + // Example idx: (1, 1), 1, 1 + // Example shape: (2, (2, 2)), 2, (2, 2) + // Unrolled shape: 2, (2, 2), 2, (2, 2) + template + __host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple& shape, + const Tuple& idx) + { + if constexpr(!IsNestedTuple(Tuple{})) + { + // Index unrolled to flatten, return shape + return shape; + } + else + { + // Iterate over shape tuple elements: + // 1. If corresponding idx element is tuple then return (will be unrolled) + // 2. If no, pack in tuple. It will be restored during unroll. + auto unrolled_shape_via_idx = generate_tuple( + [&](auto i) { + if constexpr(is_detected>>::value) + { + return shape.At(i); + } + else + { + return make_tuple(shape.At(i)); + } + }, + Number::Size()>{}); + + // Unroll and process next step + return UnrollShapeViaIdx(UnrollNestedTuple<0, 1>(unrolled_shape_via_idx), + UnrollNestedTuple<0, 1>(idx)); + } + } + + template + __host__ __device__ constexpr static auto MakeMerge1d(const Tuple& shape, + DescriptorToMerge& desc) + { + // Reverse each element in tuple + using ReversedUnrolledShape = decltype(TupleReverse(UnrollNestedTuple(shape))); + const auto merge_elems = ReversedUnrolledShape{}; + + // Generate reverted indexes (column major traverse) + using MergeElemsSequence = + typename arithmetic_sequence_gen<0, ReversedUnrolledShape::Size(), 1>::type; + const auto lower_dims = make_tuple(MergeElemsSequence::Reverse()); + const auto upper_dims = make_tuple(Sequence<0>{}); + // Merge to 1d + return transform_tensor_descriptor( + desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims); + } + + // Merge nested shape dims + // Input desc shape: 2, 2, 2, 2, 2, 2 + // Example idx: 1, 1, 1, 1 + // Example shape: 2, (2, 2), 2, (2, 2) + // Merged shape: 2, 4, 2, 4 + template + __host__ __device__ constexpr static auto + MakeMerges(const Tuple& shape, const Tuple&, DescriptorToMerge& desc) + { + const auto transforms = 
generate_tuple( + [&](auto i) { + // Compare Idx with shape + if constexpr(is_detected>>::value && + !is_detected>>::value) + { + // If shape element is tuple and idx element is Number, then merge + // Unroll and reverse tuple to traverse column-major + const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i))); + return make_merge_transform(merge_elems); + } + else + { + // If shape element is integer and idx element is tuple, passed idx is wrong + static_assert( + !(!is_detected>>::value && + is_detected>>::value), + "Wrong Idx for layout()"); + // If shape element has the same type as idx element, then pass through + return make_pass_through_transform(shape.At(i)); + } + }, + Number::Size()>{}); + + const auto lower_dims = + generate_tuple([&](auto i) { return GenerateLowerDim>(shape); }, + Number::Size()>{}); + const auto upper_dims = generate_tuple([&](auto i) { return Sequence{}; }, + Number::Size()>{}); + + return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims); + } + + template + __host__ __device__ constexpr auto TransformDesc(const Tuple& shape, + const Tuple& idx) const + { + if constexpr(Tuple::Size() == I1) + { + // 1d idx path + return MakeMerge1d(shape, descriptor_); + } + else + { + // Merge nested shape dims + // Example idx: (1, 1), 1, 1 + // Example shape: (2, (2, 2)), 2, (2, 2) + // Merged shape: (2, 4), 2, 4 + static_assert(Tuple::Size() == Tuple::Size(), + "Idx rank and Shape rank must be the same (except 1d)."); + // Unroll while IdxDims is nested + const auto unrolled_shape_via_idx = UnrollShapeViaIdx(shape, idx); + // Transform correct form of shape + return MakeMerges(unrolled_shape_via_idx, UnrollNestedTuple(idx), descriptor_); + } + } + + template + __host__ __device__ static auto MakeNaiveDescriptor(const LayoutShape& shape, + const LayoutStrides& strides) + { + const auto unrolled_shape = UnrollNestedTuple(shape); + + if constexpr(ck::is_same_v>) + { + // If shape is packed + const auto 
column_major_packed_strides = + GenerateColumnMajorPackedStrides(unrolled_shape); + return make_naive_tensor_descriptor(unrolled_shape, column_major_packed_strides); + } + else + { + const auto unrolled_strides = UnrollNestedTuple(strides); + static_assert(unrolled_shape.Size() == unrolled_strides.Size(), + "Size of strides and shape are not consistent."); + return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); + } + } + + public: + using NaiveDescriptorType = remove_cvref_t; + + /** + * \brief Layout constructor. + * + * \param shape Shape for layout. + * \param strides Strides for layout (optional if tensor is packed). + * \return Layout object. + */ + __host__ __device__ Layout() = delete; + __host__ __device__ Layout(const Shape& shape, const Strides& strides) : descriptor_{} + { + // Construct if runtime mode + if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime()) + { + // Keep only shape, strides are not needed for transforms + shape_ = shape; + descriptor_ = MakeNaiveDescriptor(shape, strides); + } + } + + __host__ __device__ Layout(const Shape& shape) : descriptor_{} + { + if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime()) + { + shape_ = shape; + descriptor_ = MakeNaiveDescriptor(shape, Strides{}); + } + } + + /** + * \brief Returns real offset to element in compile time. + * + * \tparam Idxs Tuple of indexes. + * \return Calculated offset. + */ + template + __host__ __device__ constexpr index_t operator()() const + { + using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{})); + using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{})); + return TransformedDesc{}.CalculateOffset(UnrolledIdx{}); + } + + /** + * \brief Returns real offset to element in runtime. + * + * \param Idx Tuple of indexes. + * \return Calculated offset. 
+ */ + template + __host__ __device__ index_t operator()(const Tuple& Idx) const + { + // Static to construct transformed_desc only once + static const auto transformed_desc = TransformDesc(shape_, Idx); + return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx)); + } + + /** + * \brief Length getter (product if tuple). + * + * \tparam IDim Tuple of indexes or index. + * \return Calculated size. + */ + template + __host__ __device__ constexpr index_t GetLength() const + { + const auto elem = shape_.At(Number{}); + if constexpr(is_detected>::value) + { + const auto unrolled_element = UnrollNestedTuple(elem); + return TupleReduce( + [](auto x, auto y) { return x * y; }, unrolled_element); + } + else + { + return elem; + } + } + + /** + * \brief Layout size getter (product of shape). + * + * \return Calculated size. + */ + __host__ __device__ constexpr index_t GetLength() const + { + const auto unrolled_shape = UnrollNestedTuple(shape_); + return TupleReduce([](auto x, auto y) { return x * y; }, + unrolled_shape); + } + + /** + * \brief Dimension getter. + * + * \tparam IDim Dimension idx. + * \return Calculated size. 
+ */ + template + __host__ __device__ constexpr auto Get() const + { + const auto elem = shape_.At(Number{}); + return elem; + } + + private: + NaiveDescriptorType descriptor_; + Shape shape_; +}; + +// Layout helpers +// Length getter (product if tuple) +template +__host__ __device__ constexpr index_t size(const Layout& layout) +{ + return layout.template GetLength(); +} + +// Get shape size (product of dims if tuple) +template +__host__ __device__ constexpr index_t size(const Tuple& shape) +{ + using UnrolledShape = decltype(UnrollNestedTuple(shape)); + return TupleReduce<0, UnrolledShape::Size()>([](auto x, auto y) { return x * y; }, + UnrolledShape{}); +} + +// Get dim size (could be returned from get function) +template +__host__ __device__ T constexpr size(const T& dim) +{ + return dim; +} + +// Get layout size (product of shapes) +template +__host__ __device__ constexpr index_t size(const Layout& layout) +{ + return layout.GetLength(); +} + +// Get shape element size +template +__host__ __device__ constexpr index_t size(const Tuple& shape) +{ + return size(shape.At(Number{})); +} + +// Dim getter (tuple if tuple) +template +__host__ __device__ constexpr auto get(const Layout& layout) +{ + return layout.template Get(); +} + +template +__host__ __device__ constexpr Layout make_layout(const Shape& shape, + const Strides& strides) +{ + return Layout(shape, strides); +} + +template +__host__ __device__ constexpr Layout make_layout(const Shape& shape) +{ + return Layout(shape); +} + +} // namespace tensor_transform_wrapper +} // namespace ck diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index e39ae1c23d..d7b492fe66 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -5,6 +5,7 @@ #include "functional4.hpp" #include "tuple.hpp" +#include "is_detected.hpp" namespace ck { @@ -33,6 +34,28 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple& ty); } +template 
+__host__ __device__ constexpr auto concat_tuple(const Tuple& tx, const Tuple& ty) +{ + return unpack2( + [&](auto... zs) { return Tuple{std::forward(zs)...}; }, + tx, + ty); +} + +// Support any number of tuples to concat (also 1) +template +__host__ __device__ constexpr auto concat_tuple(const Tuple& tx) +{ + return tx; +} + +template +__host__ __device__ constexpr auto concat_tuple(const Tuple& tx, const Tuples&... tuples) +{ + return concat_tuple(tx, concat_tuple(tuples...)); +} + namespace detail { template @@ -78,4 +101,69 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); } +// By default unroll to the flatten +template +__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element) +{ + return element; +} + +template +__host__ __device__ constexpr auto UnrollNestedTuple(const T& element) +{ + return make_tuple(element); +} + +template +__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple& tuple) +{ + if constexpr(Depth == MaxDepth) + { + return tuple; + } + else + { + return unpack( + [&](auto&&... 
ts) { + return concat_tuple(UnrollNestedTuple(ts)...); + }, + tuple); + } +} + +template +__host__ __device__ constexpr auto TupleReverse(const Tuple& tuple) +{ + return generate_tuple( + [&](auto i) { + using Idx = Number::Size() - i - 1>; + return tuple.At(Idx{}); + }, + Number::Size()>{}); +} + +// Reduce tuple values in specific range using Function +template +__host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple& tuple) +{ + static_assert(Idx < End, "Wrong parameters for TupleReduce"); + if constexpr(Idx + 1 == End) + { + return tuple.At(Number{}); + } + else + { + return f(tuple.At(Number{}), TupleReduce(f, tuple)); + } +} + +template +using is_tuple = decltype(std::declval().IsTuple()); + +template +__host__ __device__ constexpr auto IsNestedTuple(const Tuple&) +{ + return (is_detected::value || ...); +} + } // namespace ck