diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e46a4ab4b..3da22fc790 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ None - Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) - Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) - Support for Batched Gemm DL (#732) +- Introduce wrapper sublibrary (limited functionality) (#1071) ### Changes - Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) diff --git a/client_example/25_tensor_transforms/CMakeLists.txt b/client_example/25_tensor_transforms/CMakeLists.txt new file mode 100644 index 0000000000..d1543fb0ef --- /dev/null +++ b/client_example/25_tensor_transforms/CMakeLists.txt @@ -0,0 +1,4 @@ +add_executable(client_tensor_transform tensor_transform.cpp) +target_link_libraries(client_tensor_transform PRIVATE composable_kernel::device_other_operations) +add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp) +target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations) diff --git a/example/64_tensor_transforms/tensor_transform.cpp b/client_example/25_tensor_transforms/tensor_transform.cpp similarity index 100% rename from example/64_tensor_transforms/tensor_transform.cpp rename to client_example/25_tensor_transforms/tensor_transform.cpp diff --git a/example/64_tensor_transforms/tensor_transform_using_wrapper.cpp b/client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp similarity index 74% rename from example/64_tensor_transforms/tensor_transform_using_wrapper.cpp rename to client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp index df2449e99d..de9fcde0b4 100644 --- a/example/64_tensor_transforms/tensor_transform_using_wrapper.cpp +++ b/client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp @@ -9,7 +9,7 @@ #include "ck/utility/tuple.hpp" #include 
"ck/utility/sequence.hpp" -#include "tensor_transform_wrapper.hpp" +#include "ck/wrapper/layout.hpp" using DataType = int; @@ -17,7 +17,7 @@ template void Print1d(const Layout& layout) { std::cout << "Print1d" << std::endl; - for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size(layout); w++) + for(ck::index_t w = 0; w < ck::wrapper::size(layout); w++) { std::cout << layout(ck::make_tuple(w)) << " "; } @@ -28,9 +28,9 @@ template void Print2d(const Layout& layout) { std::cout << "Print2d" << std::endl; - for(ck::index_t h = 0; h < ck::tensor_transform_wrapper::size<0>(layout); h++) + for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++) { - for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++) + for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++) { std::cout << layout(ck::make_tuple(h, w)) << " "; } @@ -43,15 +43,11 @@ template void Print3dCustom(const Layout& layout) { std::cout << "Print3dCustom" << std::endl; - for(ck::index_t d = 0; - d < ck::tensor_transform_wrapper::size<0>(ck::tensor_transform_wrapper::get<0>(layout)); - d++) + for(ck::index_t d = 0; d < ck::wrapper::size<0>(ck::wrapper::get<0>(layout)); d++) { - for(ck::index_t h = 0; - h < ck::tensor_transform_wrapper::size<1>(ck::tensor_transform_wrapper::get<0>(layout)); - h++) + for(ck::index_t h = 0; h < ck::wrapper::size<1>(ck::wrapper::get<0>(layout)); h++) { - for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++) + for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++) { std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " "; } @@ -68,7 +64,7 @@ int main() // Basic descriptor 0, 1, 2, ... 
30, 31 (compile-time descriptor) // (dims:4,8 strides:1,4) const auto shape_4x8 = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}); - const auto layout_4x8_s1x4 = ck::tensor_transform_wrapper::make_layout(shape_4x8); + const auto layout_4x8_s1x4 = ck::wrapper::make_layout(shape_4x8); std::cout << "dims:4,8 strides:1,4" << std::endl; Print2d(layout_4x8_s1x4); using Cord1x1Type = ck::Tuple, ck::Number<1>>; @@ -77,10 +73,9 @@ int main() // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor) // dims:4,(2,4) strides:2,(1,8) - const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); - const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); - const auto layout_4x2x4_s2x1x8 = - ck::tensor_transform_wrapper::make_layout(shape_4x2x4, strides_s2x1x8); + const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); + const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); + const auto layout_4x2x4_s2x1x8 = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8); std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl; Print2d(layout_4x2x4_s2x1x8); @@ -92,7 +87,7 @@ int main() const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::make_tuple(ck::Number<2>{}, ck::Number<8>{})); static const auto layout_2x2x2x4_s1x4x2x8 = - ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8); + ck::wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8); std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl; Print2d(layout_2x2x2x4_s1x4x2x8); @@ -108,7 +103,7 @@ int main() ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}), ck::Number<8>{}); static const auto layout_2x2x2x4_s1x4x2x8_nested = - ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested); + ck::wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested); std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl; 
Print1d(layout_2x2x2x4_s1x4x2x8_nested); diff --git a/docs/doxygen/Doxyfile b/docs/doxygen/Doxyfile index 2594422095..fac9e138e1 100644 --- a/docs/doxygen/Doxyfile +++ b/docs/doxygen/Doxyfile @@ -778,7 +778,9 @@ WARN_LOGFILE = INPUT = ../../include/ck/tensor_operation/gpu/grid \ ../../include/ck/tensor_operation/gpu/block \ ../../include/ck/tensor_operation/gpu/thread \ - ../../library/include/ck/library/utility + ../../library/include/ck/library/utility \ + ../../include/ck/wrapper + # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/docs/index.rst b/docs/index.rst index 51c0c862ae..8c4aaa2b3d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -34,6 +34,7 @@ Current CK library are structured into 4 layers: * "Templated Tile Operators" layer * "Templated Kernel and Invoker" layer * "Instantiated Kernel and Invoker" layer +* "Wrapper for tensor transform operations" * "Client API" layer .. image:: data/ck_layer.png @@ -50,6 +51,7 @@ The following is a list of CK documents in the suggested reading order: tutorial_hello_world dockerhub + wrapper Supported_Primitives_Guide API_Reference_Guide Contributors_Guide diff --git a/docs/wrapper.rst b/docs/wrapper.rst new file mode 100644 index 0000000000..64fb6a4031 --- /dev/null +++ b/docs/wrapper.rst @@ -0,0 +1,54 @@ +=============== +Wrapper +=============== + +------------------------------------- +Description +------------------------------------- + +.. note:: + + The wrapper is under development and its functionality is limited. + + +CK provides a lightweight wrapper for more complex operations implemented in +the library. It allows indexing of nested layouts using a simple interface +(avoiding complex descriptor transformations). + +Example: + +.. 
code-block:: c + + const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); + const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); + const auto layout = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8); + + std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl; + for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++) + { + for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++) + { + std::cout << layout(ck::make_tuple(h, w)) << " "; + } + std::cout << std::endl; + } + +Output:: + + dims:4,(2,4) strides:2,(1,8) + 0 1 8 9 16 17 24 25 + 2 3 10 11 18 19 26 27 + 4 5 12 13 20 21 28 29 + 6 7 14 15 22 23 30 31 + +------------------------------------- +Layout +------------------------------------- + +.. doxygenstruct:: ck::wrapper::Layout + +------------------------------------- +Layout helpers +------------------------------------- + +.. doxygenfile:: layout_utils.hpp diff --git a/example/64_tensor_transforms/CMakeLists.txt b/example/64_tensor_transforms/CMakeLists.txt deleted file mode 100644 index 9d14a410e3..0000000000 --- a/example/64_tensor_transforms/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_example_executable(example_tensor_transform tensor_transform.cpp) -add_example_executable(example_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp) diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index d7b492fe66..75f2693f20 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -166,4 +166,16 @@ __host__ __device__ constexpr auto IsNestedTuple(const Tuple&) return (is_detected::value || ...); } +template +__host__ __device__ constexpr auto TupleDepth(const T&) +{ + return depth; +} + +template +__host__ __device__ constexpr auto TupleDepth(const Tuple&) +{ + return math::max(TupleDepth(Ts{})...); +} + } // namespace ck diff --git a/example/64_tensor_transforms/tensor_transform_wrapper.hpp b/include/ck/wrapper/layout.hpp similarity 
index 68% rename from example/64_tensor_transforms/tensor_transform_wrapper.hpp rename to include/ck/wrapper/layout.hpp index 71cd6091f8..b337d88a1a 100644 --- a/example/64_tensor_transforms/tensor_transform_wrapper.hpp +++ b/include/ck/wrapper/layout.hpp @@ -3,27 +3,13 @@ #pragma once -#include "ck/ck.hpp" - -#include "ck/utility/number.hpp" -#include "ck/utility/tuple.hpp" -#include "ck/utility/tuple_helper.hpp" -#include "ck/utility/sequence.hpp" -#include "ck/utility/sequence_helper.hpp" -#include "ck/utility/is_detected.hpp" - -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/wrapper/layout_utils.hpp" namespace ck { -namespace tensor_transform_wrapper { +namespace wrapper { /** - * \brief Layout wrapper - * - * \details - * Layout wrapper that performs the tensor descriptor logic. + * \brief Layout wrapper that performs the tensor descriptor logic. * * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t * (dynamic layout). It is possible to pass nested shapes @@ -32,21 +18,19 @@ namespace tensor_transform_wrapper { * (dynamic layout). Stride tuple should be nested if shape tuple is * nested. 
*/ -template > +template struct Layout { private: static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - template - using is_tuple = decltype(std::declval().IsTuple()); - // Generate packed (column-major) strides if not passed template __host__ __device__ constexpr static auto - GenerateColumnMajorPackedStrides(const Tuple& tuple) + GenerateColumnMajorPackedStrides(const Tuple& shape) { + const auto unrolled_shape = UnrollNestedTuple(shape); return generate_tuple( [&](auto i) { if constexpr(i.value == 0) @@ -56,10 +40,10 @@ struct Layout else { return TupleReduce([](auto x, auto y) { return x * y; }, - tuple); + unrolled_shape); } }, - Number::Size()>{}); + Number{}); } // Generate LowerDims in Compile-time for MergeTrasform using passed Type @@ -112,8 +96,8 @@ struct Layout // Example shape: (2, (2, 2)), 2, (2, 2) // Unrolled shape: 2, (2, 2), 2, (2, 2) template - __host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple& shape, - const Tuple& idx) + __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple& shape, + const Tuple& idx) { if constexpr(!IsNestedTuple(Tuple{})) { @@ -125,7 +109,7 @@ struct Layout // Iterate over shape tuple elements: // 1. If corresponding idx element is tuple then return (will be unrolled) // 2. If no, pack in tuple. It will be restored during unroll. 
- auto unrolled_shape_via_idx = generate_tuple( + auto aligned_shape = generate_tuple( [&](auto i) { if constexpr(is_detected>>::value) @@ -140,8 +124,8 @@ struct Layout Number::Size()>{}); // Unroll and process next step - return UnrollShapeViaIdx(UnrollNestedTuple<0, 1>(unrolled_shape_via_idx), - UnrollNestedTuple<0, 1>(idx)); + return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape), + UnrollNestedTuple<0, 1>(idx)); } } @@ -150,27 +134,24 @@ struct Layout DescriptorToMerge& desc) { // Reverse each element in tuple - using ReversedUnrolledShape = decltype(TupleReverse(UnrollNestedTuple(shape))); - const auto merge_elems = ReversedUnrolledShape{}; - + const auto merge_elems = TupleReverse(UnrollNestedTuple(shape)); // Generate reverted indexes (column major traverse) - using MergeElemsSequence = - typename arithmetic_sequence_gen<0, ReversedUnrolledShape::Size(), 1>::type; - const auto lower_dims = make_tuple(MergeElemsSequence::Reverse()); - const auto upper_dims = make_tuple(Sequence<0>{}); + using MergeElemsSequence = typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type; + const auto lower_dims = make_tuple(MergeElemsSequence::Reverse()); + const auto upper_dims = make_tuple(Sequence<0>{}); // Merge to 1d return transform_tensor_descriptor( desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims); } - // Merge nested shape dims + // Merge nested shape dims. Merge nested shape dims when idx is also nested. 
// Input desc shape: 2, 2, 2, 2, 2, 2 // Example idx: 1, 1, 1, 1 // Example shape: 2, (2, 2), 2, (2, 2) // Merged shape: 2, 4, 2, 4 template - __host__ __device__ constexpr static auto - MakeMerges(const Tuple& shape, const Tuple&, DescriptorToMerge& desc) + __host__ __device__ constexpr static auto CreateMergedDescriptor( + const Tuple& shape, const Tuple&, DescriptorToMerge& desc) { const auto transforms = generate_tuple( [&](auto i) { @@ -224,9 +205,9 @@ struct Layout static_assert(Tuple::Size() == Tuple::Size(), "Idx rank and Shape rank must be the same (except 1d)."); // Unroll while IdxDims is nested - const auto unrolled_shape_via_idx = UnrollShapeViaIdx(shape, idx); + const auto aligned_shape = AlignShapeToIdx(shape, idx); // Transform correct form of shape - return MakeMerges(unrolled_shape_via_idx, UnrollNestedTuple(idx), descriptor_); + return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), descriptor_); } } @@ -234,26 +215,21 @@ struct Layout __host__ __device__ static auto MakeNaiveDescriptor(const LayoutShape& shape, const LayoutStrides& strides) { - const auto unrolled_shape = UnrollNestedTuple(shape); - - if constexpr(ck::is_same_v>) - { - // If shape is packed - const auto column_major_packed_strides = - GenerateColumnMajorPackedStrides(unrolled_shape); - return make_naive_tensor_descriptor(unrolled_shape, column_major_packed_strides); - } - else - { - const auto unrolled_strides = UnrollNestedTuple(strides); - static_assert(unrolled_shape.Size() == unrolled_strides.Size(), - "Size of strides and shape are not consistent."); - return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); - } + const auto unrolled_shape = UnrollNestedTuple(shape); + const auto unrolled_strides = UnrollNestedTuple(strides); + static_assert(unrolled_shape.Size() == unrolled_strides.Size(), + "Size of strides and shape are not consistent."); + return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); } public: - using 
NaiveDescriptorType = remove_cvref_t; + // If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`. + using DeducedStrides = + std::conditional_t>, + remove_cvref_t, + Strides>; + using NaiveDescriptorType = + remove_cvref_t; /** * \brief Layout constructor. @@ -268,9 +244,9 @@ struct Layout // Construct if runtime mode if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime()) { - // Keep only shape, strides are not need for transforms shape_ = shape; - descriptor_ = MakeNaiveDescriptor(shape, strides); + strides_ = strides; + descriptor_ = MakeNaiveDescriptor(shape_, strides_); } } @@ -279,7 +255,8 @@ struct Layout if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime()) { shape_ = shape; - descriptor_ = MakeNaiveDescriptor(shape, Strides{}); + strides_ = GenerateColumnMajorPackedStrides(shape_); + descriptor_ = MakeNaiveDescriptor(shape_, strides_); } } @@ -338,7 +315,7 @@ struct Layout * * \return Calculated size. */ - __host__ __device__ constexpr index_t GetLength() const + __host__ __device__ constexpr index_t GetLengths() const { const auto unrolled_shape = UnrollNestedTuple(shape_); return TupleReduce([](auto x, auto y) { return x * y; }, @@ -346,80 +323,24 @@ struct Layout } /** - * \brief Dimension getter. + * \brief Shape getter. * - * \tparam IDim Dimension idx. - * \return Calculated size. + * \return Shape. */ - template - __host__ __device__ constexpr auto Get() const - { - const auto elem = shape_.At(Number{}); - return elem; - } + __host__ __device__ constexpr Shape GetShape() const { return shape_; } + + /** + * \brief Strides getter. + * + * \return Strides. 
+ */ + __host__ __device__ constexpr DeducedStrides GetStrides() const { return strides_; } private: NaiveDescriptorType descriptor_; Shape shape_; + DeducedStrides strides_; }; -// Layout helpers -// Length getter (product if tuple) -template -__host__ __device__ constexpr index_t size(const Layout& layout) -{ - return layout.template GetLength(); -} - -// Get shape size (product of dims if tuple) -template -__host__ __device__ constexpr index_t size(const Tuple& shape) -{ - using UnrolledShape = decltype(UnrollNestedTuple(shape)); - return TupleReduce<0, UnrolledShape::Size()>([](auto x, auto y) { return x * y; }, - UnrolledShape{}); -} - -// Get dim size (could be returned from get function) -template -__host__ __device__ T constexpr size(const T& dim) -{ - return dim; -} - -// Get layout size (product of shapes) -template -__host__ __device__ constexpr index_t size(const Layout& layout) -{ - return layout.GetLength(); -} - -// Get shape element size -template -__host__ __device__ constexpr index_t size(const Tuple& shape) -{ - return size(shape.At(Number{})); -} - -// Dim getter (tuple if tuple) -template -__host__ __device__ constexpr auto get(const Layout& layout) -{ - return layout.template Get(); -} - -template -__host__ __device__ constexpr Layout make_layout(const Shape& shape, - const Strides& strides) -{ - return Layout(shape, strides); -} - -template -__host__ __device__ constexpr Layout make_layout(const Shape& shape) -{ - return Layout(shape); -} - -} // namespace tensor_transform_wrapper +} // namespace wrapper } // namespace ck diff --git a/include/ck/wrapper/layout_utils.hpp b/include/ck/wrapper/layout_utils.hpp new file mode 100644 index 0000000000..fac8f33854 --- /dev/null +++ b/include/ck/wrapper/layout_utils.hpp @@ -0,0 +1,321 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/ck.hpp" + +#include "ck/utility/number.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/tuple_helper.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/utility/is_detected.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" + +namespace ck { +namespace wrapper { + +// Disable from doxygen docs generation +/// @cond +// forward declaration +template > +struct Layout; + +template +using is_tuple = decltype(std::declval().IsTuple()); +/// @endcond + +// make_* +/** + * \brief Make layout function. + * + * \tparam Shape Shape for layout. + * \tparam Strides Strides for layout. + * \return Constructed layout. + */ +template +__host__ __device__ constexpr Layout make_layout(const Shape& shape, + const Strides& strides) +{ + return Layout(shape, strides); +} + +/** + * \brief Make layout function with packed strides + * (column-major). + * + * \tparam Shape Shape for layout. + * \return Constructed layout. + */ +template +__host__ __device__ constexpr Layout make_layout(const Shape& shape) +{ + return Layout(shape); +} + +// Layout helpers +// get +/** + * \brief Get element from tuple (Shape/Strides/Idxs). + * + * \tparam idx Index to lookup. + * \param tuple Tuple to lookup. + * \return Requested element. + */ +template +__host__ __device__ constexpr auto get(const Tuple& tuple) +{ + return tuple.At(Number{}); +} + +/** + * \brief Get sub layout. + * + * \tparam idx Index to lookup. + * \param layout Layout to create sub layout. + * \return Requested sub layout. 
+ */ +template +__host__ __device__ constexpr auto get(const Layout& layout) +{ + const auto new_shape = get(layout.GetShape()); + static_assert(is_detected::value, + "Shape of sub layout must be tuple"); + if constexpr(is_same_v>) + { + // If stride not passed, create without strides + return make_layout(new_shape); + } + else + { + const auto new_strides = get(layout.GetStrides()); + static_assert(is_detected::value, + "Strides of sub layout must be tuple"); + return make_layout(new_shape, new_strides); + } +} + +/** + * \brief Hierarchical get. + * + * \tparam Idxs Indexes to lookup. + * \param elem Element to lookup. + * \return Requested element. + */ +template +__host__ __device__ constexpr auto get(const T& elem) +{ + return get(get(elem)); +} + +// size +/** + * \brief Length get (product if tuple). + * + * \tparam idx Index to lookup. + * \param layout Layout to get Shape. + * \return Requested length. + */ +template +__host__ __device__ constexpr index_t size(const Layout& layout) +{ + return layout.template GetLength(); +} + +/** + * \brief Shape size (product of dims). + * + * \param shape Shape to lookup. + * \return Requested size. + */ +template +__host__ __device__ constexpr index_t size(const Tuple& shape) +{ + const auto unrolled_shape = UnrollNestedTuple(shape); + return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; }, + unrolled_shape); +} + +// Get dim size (could be returned from get function) +/** + * \private + */ +template +__host__ __device__ T constexpr size(const T& dim) +{ + return dim; +} + +/** + * \brief Layout size (product of dims). + * + * \param layout Layout to calculate shape size. + * \return Requested size. + */ +template +__host__ __device__ constexpr index_t size(const Layout& layout) +{ + return layout.GetLengths(); +} + +/** + * \brief Length get from tuple (product if tuple). + * + * \tparam idx Index to lookup. + * \param tuple Tuple to lookup. + * \return Requested length. 
+ */ +template +__host__ __device__ constexpr index_t size(const Tuple& tuple) +{ + return size(tuple.At(Number{})); +} + +/** + * \brief Hierarchical size. + * + * \tparam Idxs Indexes to lookup. + * \param elem Element to lookup. + * \return Requested element. + */ +template +__host__ __device__ constexpr auto size(const T& elem) +{ + return size(get(elem)); +} + +// rank +/** + * \brief Get layout rank (num elements in shape). + * + * \param layout Layout to calculate rank. + * \return Requested rank. + */ +template +__host__ __device__ constexpr auto rank([[maybe_unused]] const Layout& layout) +{ + return Shape::Size(); +} + +/** + * \brief Get tuple rank (num elements in tuple). + * Return 1 if scalar passed. + * + * \param tuple Tuple to calculate rank. + * \return Requested rank. + */ +template +__host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple& tuple) +{ + return Tuple::Size(); +} + +/** + * \private + */ +template +__host__ __device__ constexpr index_t rank(const Number&) +{ + return 1; +} + +/** + * \private + */ +__host__ __device__ constexpr index_t rank(const index_t&) { return 1; } + +/** + * \brief Hierarchical rank. + * + * \tparam Idxs Indexes to lookup. + * \param elem Element to lookup. + * \return Requested rank. + */ +template +__host__ __device__ constexpr auto rank(const T& elem) +{ + return rank(get(elem)); +} + +// depth +/** + * \brief Get depth of the layout shape (return 0 if scalar). + * + * \param layout Layout to calculate depth. + * \return Requested depth. + */ +template +__host__ __device__ constexpr auto depth(const Layout& layout) +{ + return TupleDepth(layout.GetShape()); +} + +/** + * \brief Get depth of the tuple. (return 0 if scalar) + * + * \param tuple Tuple to calculate depth. + * \return Requested depth. 
+ */ +template +__host__ __device__ constexpr auto depth(const Tuple& tuple) +{ + return TupleDepth(tuple); +} + +/** + * \private + */ +template +__host__ __device__ constexpr index_t depth(const Number&) +{ + return 0; +} + +/** + * \private + */ +__host__ __device__ constexpr index_t depth(const index_t&) { return 0; } + +/** + * \brief Hierarchical depth. + * + * \tparam Idxs Indexes to lookup. + * \param elem Element to lookup. + * \return Requested depth. + */ +template +__host__ __device__ constexpr auto depth(const T& elem) +{ + return depth(get(elem)); +} + +/** + * \brief Get Layout strides. + * + * \param layout Layout to get strides. + * \return Requested strides. + */ +template +__host__ __device__ constexpr auto stride(const Layout& layout) +{ + return layout.GetStrides(); +} + +/** + * \brief Get Layout shape. + * + * \param layout Layout to get shape. + * \return Requested shape. + */ +template +__host__ __device__ constexpr auto shape(const Layout& layout) +{ + return layout.GetShape(); +} + +} // namespace wrapper +} // namespace ck diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4aaa5fcfa5..b325a3a7f8 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -149,6 +149,7 @@ add_subdirectory(batched_gemm_multi_d) add_subdirectory(grouped_convnd_bwd_data) add_subdirectory(conv_tensor_rearrange) add_subdirectory(transpose) +add_subdirectory(wrapper) if(GPU_TARGETS MATCHES "gfx11") add_subdirectory(wmma_op) endif() diff --git a/test/wrapper/CMakeLists.txt b/test/wrapper/CMakeLists.txt new file mode 100644 index 0000000000..e25ef176dd --- /dev/null +++ b/test/wrapper/CMakeLists.txt @@ -0,0 +1,2 @@ +add_gtest_executable(test_layout test_layout.cpp) +target_link_libraries(test_layout PRIVATE utility) diff --git a/test/wrapper/test_layout.cpp b/test/wrapper/test_layout.cpp new file mode 100644 index 0000000000..7d09696fbb --- /dev/null +++ b/test/wrapper/test_layout.cpp @@ -0,0 +1,481 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 
2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/wrapper/layout.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" + +class TestWrapperLayout : public ::testing::Test +{ + protected: + static constexpr auto I0 = ck::Number<0>{}; + static constexpr auto I1 = ck::Number<1>{}; + + template + void Run(Desc& desc, + Desc1d& desc_1d, + LayoutRuntime& layout_runtime, + LayoutCompiletime& layout_compiletime, + const std::vector& idxs) + { + // 1d check + EXPECT_EQ(desc_1d.GetLength(I0), ck::wrapper::size(layout_runtime)); + // Check layout compiletime and runtime result consistency + EXPECT_EQ(ck::wrapper::size(layout_runtime), ck::wrapper::size(layout_compiletime)); + + for(ck::index_t i = 0; i < desc_1d.GetLength(I0); i++) + { + const ck::index_t layout_runtime_offset_1d = layout_runtime(ck::make_tuple(i)); + const ck::index_t layout_compiletime_offset_1d = layout_compiletime(ck::make_tuple(i)); + const ck::index_t desc_offset_1d = desc_1d.CalculateOffset(ck::make_tuple(i)); + EXPECT_EQ(layout_runtime_offset_1d, desc_offset_1d); + EXPECT_EQ(layout_compiletime_offset_1d, layout_runtime_offset_1d); + } + // size(layout)-d check, don't check if access is hierarchical + if constexpr(!IsNestedTuple(Idxs{})) + { + ck::static_for<0, Idxs::Size(), 1>{}([&](auto d) { + EXPECT_EQ(desc.GetLength(ck::Number{}), ck::wrapper::size(layout_runtime)); + EXPECT_EQ(ck::wrapper::size(layout_runtime), + ck::wrapper::size(layout_compiletime)); + }); + } + for(const auto idx : idxs) + { + const ck::index_t layout_runtime_offset = layout_runtime(idx); + const ck::index_t layout_compiletime_offset = layout_compiletime(idx); + const ck::index_t desc_offset = + desc.CalculateOffset(UnrollNestedTuple(idx)); // Unroll if nested + 
EXPECT_EQ(layout_runtime_offset, desc_offset); + EXPECT_EQ(layout_runtime_offset, layout_compiletime_offset); + } + } +}; + +TEST_F(TestWrapperLayout, 2d) +{ + // dims:(4, 3) strides:(1, 4) + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + constexpr ck::index_t s1 = 1; + constexpr ck::index_t s0 = 4; + const auto desc = + ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{})); + // Reverse due to column major + const auto desc_1d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1))), + ck::make_tuple(ck::Sequence<1, 0>{}), + ck::make_tuple(ck::Sequence<0>{})); + const auto layout_runtime = ck::wrapper::make_layout(ck::make_tuple(d1, d0)); + const auto layout_compiletime = + ck::wrapper::make_layout(ck::make_tuple(ck::Number{}, ck::Number{})); + std::vector> idxs; + + for(ck::index_t h = 0; h < d1; h++) + { + for(ck::index_t w = 0; w < d0; w++) + { + idxs.emplace_back(h, w); + } + } + + this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs); +} + +TEST_F(TestWrapperLayout, 3d_nested) +{ + // dims:((2, 3), 4, 3) strides:((2, 4), 12, 48) + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + constexpr ck::index_t s3 = 2; + constexpr ck::index_t s2 = 4; + constexpr ck::index_t s1 = 12; + constexpr ck::index_t s0 = 48; + const auto desc = ck::make_naive_tensor_descriptor( + ck::make_tuple(ck::Number{}, ck::Number{}, ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{}, ck::Number{}, ck::Number{})); + // Reverse due to column major + const auto desc_1d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3))), + ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}), + ck::make_tuple(ck::Sequence<0>{})); + const auto desc_3d = transform_tensor_descriptor( + desc, + 
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3)), + ck::make_pass_through_transform(d1), + ck::make_pass_through_transform(d0)), + ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{})); + const auto layout_runtime = + ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(d3, d2), d1, d0), + ck::make_tuple(ck::make_tuple(s3, s2), s1, s0)); + const auto layout_compiletime = ck::wrapper::make_layout( + ck::make_tuple( + ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}, ck::Number{}), + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), + ck::Number{}, + ck::Number{})); + std::vector> idxs_3d; + + for(ck::index_t d = 0; d < d2 * d3; d++) + { + for(ck::index_t h = 0; h < d1; h++) + { + for(ck::index_t w = 0; w < d0; w++) + { + idxs_3d.emplace_back(d, h, w); + } + } + } + this->Run(desc_3d, desc_1d, layout_runtime, layout_compiletime, idxs_3d); + + // Check also 4d iteration + std::vector, ck::index_t, ck::index_t>> idxs_4d; + + for(ck::index_t e = 0; e < d3; e++) + { + for(ck::index_t d = 0; d < d2; d++) + { + for(ck::index_t h = 0; h < d1; h++) + { + for(ck::index_t w = 0; w < d0; w++) + { + idxs_4d.emplace_back(ck::make_tuple(e, d), h, w); + } + } + } + } + this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_4d); +} + +TEST_F(TestWrapperLayout, 2d_nested) +{ + // dims:((2, 3), (4, 3)) strides:((2, 4), (48, 12)) + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + constexpr ck::index_t s3 = 2; + constexpr ck::index_t s2 = 4; + constexpr ck::index_t s1 = 48; + constexpr ck::index_t s0 = 12; + const auto desc = ck::make_naive_tensor_descriptor( + ck::make_tuple(ck::Number{}, ck::Number{}, ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{}, ck::Number{}, ck::Number{})); + // Reverse due to column major + const auto desc_1d = 
transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3))), + ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}), + ck::make_tuple(ck::Sequence<0>{})); + const auto desc_2d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3)), + ck::make_merge_transform(ck::make_tuple(d0, d1))), + ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); + const auto layout_runtime = + ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(d3, d2), ck::make_tuple(d1, d0)), + ck::make_tuple(ck::make_tuple(s3, s2), ck::make_tuple(s1, s0))); + const auto layout_compiletime = ck::wrapper::make_layout( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{})), + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{}))); + std::vector> idxs_2d; + + for(ck::index_t h = 0; h < d2 * d3; h++) + { + for(ck::index_t w = 0; w < d0 * d1; w++) + { + idxs_2d.emplace_back(h, w); + } + } + this->Run(desc_2d, desc_1d, layout_runtime, layout_compiletime, idxs_2d); + // Check also 4d iteration + std::vector, ck::Tuple>> + idxs_4d; + + for(ck::index_t e = 0; e < d3; e++) + { + for(ck::index_t d = 0; d < d2; d++) + { + for(ck::index_t h = 0; h < d1; h++) + { + for(ck::index_t w = 0; w < d0; w++) + { + idxs_4d.emplace_back(ck::make_tuple(e, d), ck::make_tuple(h, w)); + } + } + } + } + this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_4d); +} + +TEST_F(TestWrapperLayout, 3d_double_nested) +{ + // dims:(((2, 2), 3), (4, 3)) strides:(((2, 4), 8), (96, 24)) + constexpr ck::index_t d4 = 2; + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + constexpr ck::index_t s4 = 2; + constexpr ck::index_t s3 = 4; + constexpr ck::index_t s2 = 8; + constexpr ck::index_t s1 = 96; + 
constexpr ck::index_t s0 = 24; + const auto desc = ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number{}, + ck::Number{}, + ck::Number{}, + ck::Number{}, + ck::Number{}), + ck::make_tuple(ck::Number{}, + ck::Number{}, + ck::Number{}, + ck::Number{}, + ck::Number{})); + // Reverse due to column major + const auto desc_1d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3, d4))), + ck::make_tuple(ck::Sequence<4, 3, 2, 1, 0>{}), + ck::make_tuple(ck::Sequence<0>{})); + const auto desc_3d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d3, d4)), + ck::make_pass_through_transform(d2), + ck::make_merge_transform(ck::make_tuple(d0, d1))), + ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<4, 3>{}), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{})); + const auto desc_2d = transform_tensor_descriptor( + desc_3d, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3 * d4)), + ck::make_pass_through_transform(d1 * d0)), + ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); + const auto layout_runtime = ck::wrapper::make_layout( + ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)), + ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, s3), s2), ck::make_tuple(s1, s0))); + const auto layout_compiletime = ck::wrapper::make_layout( + ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{})), + ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{}))); + std::vector> idxs_2d; + + for(ck::index_t h = 0; h < d2 * d3 * d4; h++) + { + for(ck::index_t w = 0; w < d0 * d1; w++) + { + idxs_2d.emplace_back(h, w); + } + } + this->Run(desc_2d, desc_1d, layout_runtime, 
layout_compiletime, idxs_2d); + // Check also 3d iteration + std::vector, ck::index_t>> idxs_3d; + + for(ck::index_t d = 0; d < d3 * d4; d++) + { + for(ck::index_t h = 0; h < d2; h++) + { + for(ck::index_t w = 0; w < d1 * d0; w++) + { + idxs_3d.emplace_back(ck::make_tuple(d, h), w); + } + } + } + this->Run(desc_3d, desc_1d, layout_runtime, layout_compiletime, idxs_3d); + // Check also 5d iteration + std::vector, ck::index_t>, + ck::Tuple>> + idxs_5d; + + for(ck::index_t f = 0; f < d4; f++) + { + for(ck::index_t e = 0; e < d3; e++) + { + for(ck::index_t d = 0; d < d2; d++) + { + for(ck::index_t h = 0; h < d1; h++) + { + for(ck::index_t w = 0; w < d0; w++) + { + idxs_5d.emplace_back(ck::make_tuple(ck::make_tuple(f, e), d), + ck::make_tuple(h, w)); + } + } + } + } + } + this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_5d); +} + +TEST(TestLayoutHelpers, SizeAndGet) +{ + // dims:(((2, 2), 3), (4, 3)) + constexpr ck::index_t d4 = 2; + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + const auto layout_runtime = ck::wrapper::make_layout( + ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0))); + const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{}))); + + // Size of layout + EXPECT_EQ(ck::wrapper::size(layout_runtime), d4 * d3 * d2 * d1 * d0); + EXPECT_EQ(ck::wrapper::size(layout_compiletime), d4 * d3 * d2 * d1 * d0); + + // Size of dims + EXPECT_EQ(ck::wrapper::size<0>(layout_runtime), d4 * d3 * d2); + EXPECT_EQ(ck::wrapper::size<0>(layout_compiletime), d4 * d3 * d2); + EXPECT_EQ(ck::wrapper::size<1>(layout_runtime), d1 * d0); + EXPECT_EQ(ck::wrapper::size<1>(layout_compiletime), d1 * d0); + + // Access through new layout (using get with layout object) + 
EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_runtime)), d4 * d3); + EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_compiletime)), d4 * d3); + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_runtime)), d2); + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_compiletime)), d2); + + EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_runtime))), d4); + EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_compiletime))), + d4); + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_runtime))), d3); + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_compiletime))), + d3); + + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_runtime)), d2); + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_compiletime)), d2); + + EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<1>(layout_runtime)), d1); + EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<1>(layout_compiletime)), d1); + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<1>(layout_runtime)), d0); + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<1>(layout_compiletime)), d0); +} + +TEST(TestLayoutHelpers, DepthAndRank) +{ + // dims:(((2, 2), 3), (4, 3)) + constexpr ck::index_t d4 = 2; + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + const auto layout_runtime = ck::wrapper::make_layout( + ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0))); + const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{}))); + + EXPECT_EQ(ck::wrapper::depth(layout_runtime), 3); + EXPECT_EQ(ck::wrapper::depth(layout_compiletime), 3); + EXPECT_EQ(ck::wrapper::depth(ck::make_tuple(ck::make_tuple(d4, d3), d2)), 2); + // Check for integer + 
EXPECT_EQ(ck::wrapper::depth(d0), 0); + + EXPECT_EQ(ck::wrapper::rank(layout_runtime), 2); + EXPECT_EQ(ck::wrapper::rank(layout_compiletime), 2); + EXPECT_EQ(ck::wrapper::rank(ck::make_tuple(ck::make_tuple(d4, d3), d2)), 2); + // Check for integer + EXPECT_EQ(ck::wrapper::rank(d0), 1); +} + +TEST(TestLayoutHelpers, ShapeAndStrides) +{ + // dims:(((2, 2), 3), (4, 3)) + constexpr ck::index_t d4 = 2; + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + constexpr ck::index_t s4 = 2; + constexpr ck::index_t s3 = 4; + constexpr ck::index_t s2 = 8; + constexpr ck::index_t s1 = 96; + constexpr ck::index_t s0 = 24; + const auto shape_compiletime = ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{})); + const auto strides_compiletime = ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{})); + const auto shape_runtime = + ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)); + const auto strides_runtime = + ck::make_tuple(ck::make_tuple(ck::make_tuple(s4, s3), s2), ck::make_tuple(s1, s0)); + const auto layout_runtime = ck::wrapper::make_layout(shape_runtime, strides_runtime); + const auto layout_compiletime = + ck::wrapper::make_layout(shape_compiletime, strides_compiletime); + + constexpr bool check_compiletime_shape = + std::is_same_v::type, + decltype(shape(layout_compiletime))>; + constexpr bool check_compiletime_strides = + std::is_same_v::type, + decltype(stride(layout_compiletime))>; + constexpr bool check_runtime_shape = + std::is_same_v::type, + decltype(shape(layout_runtime))>; + constexpr bool check_runtime_strides = + std::is_same_v::type, + decltype(stride(layout_runtime))>; + EXPECT_TRUE(check_compiletime_shape); + EXPECT_TRUE(check_compiletime_strides); + EXPECT_TRUE(check_runtime_shape); + 
EXPECT_TRUE(check_runtime_strides); +} + +TEST(TestLayoutHelpers, Hierarchical) +{ + // dims:(((2, 2), 3), (4, 3)) + constexpr ck::index_t d4 = 2; + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + const auto runtime_shape = + ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)); + const auto layout_runtime = ck::wrapper::make_layout(runtime_shape); + const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{}))); + + EXPECT_EQ((ck::wrapper::rank<0, 0>(runtime_shape)), 2); + EXPECT_EQ((ck::wrapper::rank<0, 0>(layout_runtime)), 2); + EXPECT_EQ((ck::wrapper::rank<0, 0>(layout_compiletime)), 2); + + EXPECT_EQ((ck::wrapper::depth<0, 0>(runtime_shape)), 1); + EXPECT_EQ((ck::wrapper::depth<0, 0>(layout_runtime)), 1); + EXPECT_EQ((ck::wrapper::depth<0, 0>(layout_compiletime)), 1); + + EXPECT_EQ((ck::wrapper::size<0, 0>(runtime_shape)), d4 * d3); + EXPECT_EQ((ck::wrapper::size<0, 0>(layout_runtime)), d4 * d3); + EXPECT_EQ((ck::wrapper::size<0, 0>(layout_compiletime)), d4 * d3); + + EXPECT_EQ((ck::wrapper::get<0, 0, 0>(runtime_shape)), d4); +}