mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Introduce wrapper library (#1071)
* Introduce wrapper library * Update cmake files * Revert "Update cmake files" This reverts commitc27f88b565. * Fix comments [ROCm/composable_kernel commit:836b7e557d]
This commit is contained in:
@@ -19,6 +19,7 @@ None
|
||||
- Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
|
||||
- Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799)
|
||||
- Support for Batched Gemm DL (#732)
|
||||
- Introduce wrapper sublibrary (limited functionality) (#1071)
|
||||
|
||||
### Changes
|
||||
- Changed the grouped convolution API to maintain consistency with other convolution kernels (#817)
|
||||
|
||||
4
client_example/25_tensor_transforms/CMakeLists.txt
Normal file
4
client_example/25_tensor_transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
add_executable(client_tensor_transform tensor_transform.cpp)
|
||||
target_link_libraries(client_tensor_transform PRIVATE composable_kernel::device_other_operations)
|
||||
add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
|
||||
target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
#include "tensor_transform_wrapper.hpp"
|
||||
#include "ck/wrapper/layout.hpp"
|
||||
|
||||
using DataType = int;
|
||||
|
||||
@@ -17,7 +17,7 @@ template <typename Layout>
|
||||
void Print1d(const Layout& layout)
|
||||
{
|
||||
std::cout << "Print1d" << std::endl;
|
||||
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size(layout); w++)
|
||||
for(ck::index_t w = 0; w < ck::wrapper::size(layout); w++)
|
||||
{
|
||||
std::cout << layout(ck::make_tuple(w)) << " ";
|
||||
}
|
||||
@@ -28,9 +28,9 @@ template <typename Layout>
|
||||
void Print2d(const Layout& layout)
|
||||
{
|
||||
std::cout << "Print2d" << std::endl;
|
||||
for(ck::index_t h = 0; h < ck::tensor_transform_wrapper::size<0>(layout); h++)
|
||||
for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++)
|
||||
{
|
||||
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++)
|
||||
for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++)
|
||||
{
|
||||
std::cout << layout(ck::make_tuple(h, w)) << " ";
|
||||
}
|
||||
@@ -43,15 +43,11 @@ template <typename Layout>
|
||||
void Print3dCustom(const Layout& layout)
|
||||
{
|
||||
std::cout << "Print3dCustom" << std::endl;
|
||||
for(ck::index_t d = 0;
|
||||
d < ck::tensor_transform_wrapper::size<0>(ck::tensor_transform_wrapper::get<0>(layout));
|
||||
d++)
|
||||
for(ck::index_t d = 0; d < ck::wrapper::size<0>(ck::wrapper::get<0>(layout)); d++)
|
||||
{
|
||||
for(ck::index_t h = 0;
|
||||
h < ck::tensor_transform_wrapper::size<1>(ck::tensor_transform_wrapper::get<0>(layout));
|
||||
h++)
|
||||
for(ck::index_t h = 0; h < ck::wrapper::size<1>(ck::wrapper::get<0>(layout)); h++)
|
||||
{
|
||||
for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++)
|
||||
for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++)
|
||||
{
|
||||
std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " ";
|
||||
}
|
||||
@@ -68,7 +64,7 @@ int main()
|
||||
// Basic descriptor 0, 1, 2, ... 30, 31 (compile-time descriptor)
|
||||
// (dims:4,8 strides:1,4)
|
||||
const auto shape_4x8 = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{});
|
||||
const auto layout_4x8_s1x4 = ck::tensor_transform_wrapper::make_layout(shape_4x8);
|
||||
const auto layout_4x8_s1x4 = ck::wrapper::make_layout(shape_4x8);
|
||||
std::cout << "dims:4,8 strides:1,4" << std::endl;
|
||||
Print2d(layout_4x8_s1x4);
|
||||
using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
|
||||
@@ -77,10 +73,9 @@ int main()
|
||||
|
||||
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor)
|
||||
// dims:4,(2,4) strides:2,(1,8)
|
||||
const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4));
|
||||
const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8));
|
||||
const auto layout_4x2x4_s2x1x8 =
|
||||
ck::tensor_transform_wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
|
||||
const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4));
|
||||
const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8));
|
||||
const auto layout_4x2x4_s2x1x8 = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
|
||||
|
||||
std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
|
||||
Print2d(layout_4x2x4_s2x1x8);
|
||||
@@ -92,7 +87,7 @@ int main()
|
||||
const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}),
|
||||
ck::make_tuple(ck::Number<2>{}, ck::Number<8>{}));
|
||||
static const auto layout_2x2x2x4_s1x4x2x8 =
|
||||
ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8);
|
||||
ck::wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8);
|
||||
|
||||
std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
|
||||
Print2d(layout_2x2x2x4_s1x4x2x8);
|
||||
@@ -108,7 +103,7 @@ int main()
|
||||
ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}),
|
||||
ck::Number<8>{});
|
||||
static const auto layout_2x2x2x4_s1x4x2x8_nested =
|
||||
ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested);
|
||||
ck::wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested);
|
||||
|
||||
std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl;
|
||||
Print1d(layout_2x2x2x4_s1x4x2x8_nested);
|
||||
@@ -778,7 +778,9 @@ WARN_LOGFILE =
|
||||
INPUT = ../../include/ck/tensor_operation/gpu/grid \
|
||||
../../include/ck/tensor_operation/gpu/block \
|
||||
../../include/ck/tensor_operation/gpu/thread \
|
||||
../../library/include/ck/library/utility
|
||||
../../library/include/ck/library/utility \
|
||||
../../include/ck/wrapper
|
||||
|
||||
|
||||
# This tag can be used to specify the character encoding of the source files
|
||||
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
|
||||
|
||||
@@ -34,6 +34,7 @@ Current CK library are structured into 4 layers:
|
||||
* "Templated Tile Operators" layer
|
||||
* "Templated Kernel and Invoker" layer
|
||||
* "Instantiated Kernel and Invoker" layer
|
||||
* "Wrapper for tensor transform operations"
|
||||
* "Client API" layer
|
||||
|
||||
.. image:: data/ck_layer.png
|
||||
@@ -50,6 +51,7 @@ The following is a list of CK documents in the suggested reading order:
|
||||
|
||||
tutorial_hello_world
|
||||
dockerhub
|
||||
wrapper
|
||||
Supported_Primitives_Guide
|
||||
API_Reference_Guide
|
||||
Contributors_Guide
|
||||
|
||||
54
docs/wrapper.rst
Normal file
54
docs/wrapper.rst
Normal file
@@ -0,0 +1,54 @@
|
||||
===============
|
||||
Wrapper
|
||||
===============
|
||||
|
||||
-------------------------------------
|
||||
Description
|
||||
-------------------------------------
|
||||
|
||||
.. note::
|
||||
|
||||
The wrapper is under development and its functionality is limited.
|
||||
|
||||
|
||||
CK provides a lightweight wrapper for more complex operations implemented in
|
||||
the library. It allows indexing of nested layouts using a simple interface
|
||||
(avoiding complex descriptor transformations).
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: c
|
||||
|
||||
const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4));
|
||||
const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8));
|
||||
const auto layout = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
|
||||
|
||||
std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
|
||||
for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++)
|
||||
{
|
||||
for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++)
|
||||
{
|
||||
std::cout << layout(ck::make_tuple(h, w)) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
Output::
|
||||
|
||||
dims:4,(2,4) strides:2,(1,8)
|
||||
0 1 8 9 16 17 24 25
|
||||
2 3 10 11 18 19 26 27
|
||||
4 5 12 13 20 21 28 29
|
||||
6 7 14 15 22 23 30 31
|
||||
|
||||
-------------------------------------
|
||||
Layout
|
||||
-------------------------------------
|
||||
|
||||
.. doxygenstruct:: ck::wrapper::Layout
|
||||
|
||||
-------------------------------------
|
||||
Layout helpers
|
||||
-------------------------------------
|
||||
|
||||
.. doxygenfile:: layout_utils.hpp
|
||||
@@ -1,2 +0,0 @@
|
||||
add_example_executable(example_tensor_transform tensor_transform.cpp)
|
||||
add_example_executable(example_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
|
||||
@@ -166,4 +166,16 @@ __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
|
||||
return (is_detected<is_tuple, Ts>::value || ...);
|
||||
}
|
||||
|
||||
template <index_t depth = 0, typename T>
|
||||
__host__ __device__ constexpr auto TupleDepth(const T&)
|
||||
{
|
||||
return depth;
|
||||
}
|
||||
|
||||
template <index_t depth = 0, typename... Ts>
|
||||
__host__ __device__ constexpr auto TupleDepth(const Tuple<Ts...>&)
|
||||
{
|
||||
return math::max(TupleDepth<depth + 1>(Ts{})...);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -3,27 +3,13 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/tuple_helper.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/utility/sequence_helper.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
#include "ck/wrapper/layout_utils.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_transform_wrapper {
|
||||
namespace wrapper {
|
||||
|
||||
/**
|
||||
* \brief Layout wrapper
|
||||
*
|
||||
* \details
|
||||
* Layout wrapper that performs the tensor descriptor logic.
|
||||
* \brief Layout wrapper that performs the tensor descriptor logic.
|
||||
*
|
||||
* \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
|
||||
* (dynamic layout). It is possible to pass nested shapes
|
||||
@@ -32,21 +18,19 @@ namespace tensor_transform_wrapper {
|
||||
* (dynamic layout). Stride tuple should be nested if shape tuple is
|
||||
* nested.
|
||||
*/
|
||||
template <typename Shape, typename Strides = Tuple<>>
|
||||
template <typename Shape, typename Strides>
|
||||
struct Layout
|
||||
{
|
||||
private:
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
// Generate packed (column-major) strides if not passed
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr static auto
|
||||
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& tuple)
|
||||
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i.value == 0)
|
||||
@@ -56,10 +40,10 @@ struct Layout
|
||||
else
|
||||
{
|
||||
return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
|
||||
tuple);
|
||||
unrolled_shape);
|
||||
}
|
||||
},
|
||||
Number<Tuple<Ts...>::Size()>{});
|
||||
Number<decltype(unrolled_shape)::Size()>{});
|
||||
}
|
||||
|
||||
// Generate LowerDims in Compile-time for MergeTrasform using passed Type
|
||||
@@ -112,8 +96,8 @@ struct Layout
|
||||
// Example shape: (2, (2, 2)), 2, (2, 2)
|
||||
// Unrolled shape: 2, (2, 2), 2, (2, 2)
|
||||
template <typename... ShapeDims, typename... IdxDims>
|
||||
__host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple<ShapeDims...>& shape,
|
||||
const Tuple<IdxDims...>& idx)
|
||||
__host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
|
||||
const Tuple<IdxDims...>& idx)
|
||||
{
|
||||
if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
|
||||
{
|
||||
@@ -125,7 +109,7 @@ struct Layout
|
||||
// Iterate over shape tuple elements:
|
||||
// 1. If corresponding idx element is tuple then return (will be unrolled)
|
||||
// 2. If no, pack in tuple. It will be restored during unroll.
|
||||
auto unrolled_shape_via_idx = generate_tuple(
|
||||
auto aligned_shape = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(is_detected<is_tuple,
|
||||
tuple_element_t<i, Tuple<IdxDims...>>>::value)
|
||||
@@ -140,8 +124,8 @@ struct Layout
|
||||
Number<Tuple<IdxDims...>::Size()>{});
|
||||
|
||||
// Unroll and process next step
|
||||
return UnrollShapeViaIdx(UnrollNestedTuple<0, 1>(unrolled_shape_via_idx),
|
||||
UnrollNestedTuple<0, 1>(idx));
|
||||
return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape),
|
||||
UnrollNestedTuple<0, 1>(idx));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,27 +134,24 @@ struct Layout
|
||||
DescriptorToMerge& desc)
|
||||
{
|
||||
// Reverse each element in tuple
|
||||
using ReversedUnrolledShape = decltype(TupleReverse(UnrollNestedTuple(shape)));
|
||||
const auto merge_elems = ReversedUnrolledShape{};
|
||||
|
||||
const auto merge_elems = TupleReverse(UnrollNestedTuple(shape));
|
||||
// Generate reverted indexes (column major traverse)
|
||||
using MergeElemsSequence =
|
||||
typename arithmetic_sequence_gen<0, ReversedUnrolledShape::Size(), 1>::type;
|
||||
const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
|
||||
const auto upper_dims = make_tuple(Sequence<0>{});
|
||||
using MergeElemsSequence = typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type;
|
||||
const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
|
||||
const auto upper_dims = make_tuple(Sequence<0>{});
|
||||
// Merge to 1d
|
||||
return transform_tensor_descriptor(
|
||||
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
|
||||
}
|
||||
|
||||
// Merge nested shape dims
|
||||
// Merge nested shape dims. Merge nested shape dims when idx is also nested.
|
||||
// Input desc shape: 2, 2, 2, 2, 2, 2
|
||||
// Example idx: 1, 1, 1, 1
|
||||
// Example shape: 2, (2, 2), 2, (2, 2)
|
||||
// Merged shape: 2, 4, 2, 4
|
||||
template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
|
||||
__host__ __device__ constexpr static auto
|
||||
MakeMerges(const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
|
||||
__host__ __device__ constexpr static auto CreateMergedDescriptor(
|
||||
const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
|
||||
{
|
||||
const auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -224,9 +205,9 @@ struct Layout
|
||||
static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
|
||||
"Idx rank and Shape rank must be the same (except 1d).");
|
||||
// Unroll while IdxDims is nested
|
||||
const auto unrolled_shape_via_idx = UnrollShapeViaIdx(shape, idx);
|
||||
const auto aligned_shape = AlignShapeToIdx(shape, idx);
|
||||
// Transform correct form of shape
|
||||
return MakeMerges(unrolled_shape_via_idx, UnrollNestedTuple(idx), descriptor_);
|
||||
return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), descriptor_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -234,26 +215,21 @@ struct Layout
|
||||
__host__ __device__ static auto MakeNaiveDescriptor(const LayoutShape& shape,
|
||||
const LayoutStrides& strides)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
|
||||
if constexpr(ck::is_same_v<LayoutStrides, Tuple<>>)
|
||||
{
|
||||
// If shape is packed
|
||||
const auto column_major_packed_strides =
|
||||
GenerateColumnMajorPackedStrides(unrolled_shape);
|
||||
return make_naive_tensor_descriptor(unrolled_shape, column_major_packed_strides);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto unrolled_strides = UnrollNestedTuple(strides);
|
||||
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
|
||||
"Size of strides and shape are not consistent.");
|
||||
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
|
||||
}
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
const auto unrolled_strides = UnrollNestedTuple(strides);
|
||||
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
|
||||
"Size of strides and shape are not consistent.");
|
||||
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
|
||||
}
|
||||
|
||||
public:
|
||||
using NaiveDescriptorType = remove_cvref_t<decltype(MakeNaiveDescriptor(Shape{}, Strides{}))>;
|
||||
// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
|
||||
using DeducedStrides =
|
||||
std::conditional_t<is_same_v<Strides, Tuple<>>,
|
||||
remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
|
||||
Strides>;
|
||||
using NaiveDescriptorType =
|
||||
remove_cvref_t<decltype(MakeNaiveDescriptor(Shape{}, DeducedStrides{}))>;
|
||||
|
||||
/**
|
||||
* \brief Layout constructor.
|
||||
@@ -268,9 +244,9 @@ struct Layout
|
||||
// Construct if runtime mode
|
||||
if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
|
||||
{
|
||||
// Keep only shape, strides are not need for transforms
|
||||
shape_ = shape;
|
||||
descriptor_ = MakeNaiveDescriptor(shape, strides);
|
||||
strides_ = strides;
|
||||
descriptor_ = MakeNaiveDescriptor(shape_, strides_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,7 +255,8 @@ struct Layout
|
||||
if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
|
||||
{
|
||||
shape_ = shape;
|
||||
descriptor_ = MakeNaiveDescriptor(shape, Strides{});
|
||||
strides_ = GenerateColumnMajorPackedStrides(shape_);
|
||||
descriptor_ = MakeNaiveDescriptor(shape_, strides_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -338,7 +315,7 @@ struct Layout
|
||||
*
|
||||
* \return Calculated size.
|
||||
*/
|
||||
__host__ __device__ constexpr index_t GetLength() const
|
||||
__host__ __device__ constexpr index_t GetLengths() const
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape_);
|
||||
return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
|
||||
@@ -346,80 +323,24 @@ struct Layout
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Dimension getter.
|
||||
* \brief Shape getter.
|
||||
*
|
||||
* \tparam IDim Dimension idx.
|
||||
* \return Calculated size.
|
||||
* \return Shape.
|
||||
*/
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr auto Get() const
|
||||
{
|
||||
const auto elem = shape_.At(Number<IDim>{});
|
||||
return elem;
|
||||
}
|
||||
__host__ __device__ constexpr Shape GetShape() const { return shape_; }
|
||||
|
||||
/**
|
||||
* \brief Strides getter.
|
||||
*
|
||||
* \return Strides.
|
||||
*/
|
||||
__host__ __device__ constexpr DeducedStrides GetStrides() const { return strides_; }
|
||||
|
||||
private:
|
||||
NaiveDescriptorType descriptor_;
|
||||
Shape shape_;
|
||||
DeducedStrides strides_;
|
||||
};
|
||||
|
||||
// Layout helpers
|
||||
// Length getter (product if tuple)
|
||||
template <index_t idx, typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return layout.template GetLength<idx>();
|
||||
}
|
||||
|
||||
// Get shape size (product of dims if tuple)
|
||||
template <typename... ShapeDims>
|
||||
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
|
||||
{
|
||||
using UnrolledShape = decltype(UnrollNestedTuple(shape));
|
||||
return TupleReduce<0, UnrolledShape::Size()>([](auto x, auto y) { return x * y; },
|
||||
UnrolledShape{});
|
||||
}
|
||||
|
||||
// Get dim size (could be returned from get function)
|
||||
template <typename T>
|
||||
__host__ __device__ T constexpr size(const T& dim)
|
||||
{
|
||||
return dim;
|
||||
}
|
||||
|
||||
// Get layout size (product of shapes)
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return layout.GetLength();
|
||||
}
|
||||
|
||||
// Get shape element size
|
||||
template <index_t idx, typename... ShapeDims>
|
||||
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
|
||||
{
|
||||
return size(shape.At(Number<idx>{}));
|
||||
}
|
||||
|
||||
// Dim getter (tuple if tuple)
|
||||
template <index_t idx, typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return layout.template Get<idx>();
|
||||
}
|
||||
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
|
||||
const Strides& strides)
|
||||
{
|
||||
return Layout<Shape, Strides>(shape, strides);
|
||||
}
|
||||
|
||||
template <typename Shape>
|
||||
__host__ __device__ constexpr Layout<Shape> make_layout(const Shape& shape)
|
||||
{
|
||||
return Layout<Shape>(shape);
|
||||
}
|
||||
|
||||
} // namespace tensor_transform_wrapper
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
321
include/ck/wrapper/layout_utils.hpp
Normal file
321
include/ck/wrapper/layout_utils.hpp
Normal file
@@ -0,0 +1,321 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/tuple_helper.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/utility/sequence_helper.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond
|
||||
// forward declaration
|
||||
template <typename Shape, typename Strides = Tuple<>>
|
||||
struct Layout;
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
/// @endcond
|
||||
|
||||
// make_*
|
||||
/**
|
||||
* \brief Make layout function.
|
||||
*
|
||||
* \tparam Shape Shape for layout.
|
||||
* \tparam Strides Strides for layout.
|
||||
* \return Constructed layout.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
|
||||
const Strides& strides)
|
||||
{
|
||||
return Layout<Shape, Strides>(shape, strides);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Make layout function with packed strides
|
||||
* (column-major).
|
||||
*
|
||||
* \tparam Shape Shape for layout.
|
||||
* \return Constructed layout.
|
||||
*/
|
||||
template <typename Shape>
|
||||
__host__ __device__ constexpr Layout<Shape> make_layout(const Shape& shape)
|
||||
{
|
||||
return Layout<Shape>(shape);
|
||||
}
|
||||
|
||||
// Layout helpers
|
||||
// get
|
||||
/**
|
||||
* \brief Get element from tuple (Shape/Strides/Idxs).
|
||||
*
|
||||
* \tparam idx Index to lookup.
|
||||
* \param tuple Tuple to lookup.
|
||||
* \return Requsted element.
|
||||
*/
|
||||
template <index_t idx, typename... Dims>
|
||||
__host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
|
||||
{
|
||||
return tuple.At(Number<idx>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get sub layout.
|
||||
*
|
||||
* \tparam idx Index to lookup.
|
||||
* \param layout Layout to create sub layout.
|
||||
* \return Requsted sub layout.
|
||||
*/
|
||||
template <index_t idx, typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
const auto new_shape = get<idx>(layout.GetShape());
|
||||
static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
|
||||
"Shape of sub layout must be tuple");
|
||||
if constexpr(is_same_v<Strides, Tuple<>>)
|
||||
{
|
||||
// If stride not passed, create without strides
|
||||
return make_layout(new_shape);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto new_strides = get<idx>(layout.GetStrides());
|
||||
static_assert(is_detected<is_tuple, decltype(new_strides)>::value,
|
||||
"Strides of sub layout must be tuple");
|
||||
return make_layout(new_shape, new_strides);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Hierarchical get.
|
||||
*
|
||||
* \tparam Idxs Indexes to lookup.
|
||||
* \param elem Element to lookup.
|
||||
* \return Requsted element.
|
||||
*/
|
||||
template <index_t Idx, index_t... Idxs, typename T>
|
||||
__host__ __device__ constexpr auto get(const T& elem)
|
||||
{
|
||||
return get<Idxs...>(get<Idx>(elem));
|
||||
}
|
||||
|
||||
// size
|
||||
/**
|
||||
* \brief Length get (product if tuple).
|
||||
*
|
||||
* \tparam idx Index to lookup.
|
||||
* \param layout Layout to get Shape.
|
||||
* \return Requsted length.
|
||||
*/
|
||||
template <index_t idx, typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return layout.template GetLength<idx>();
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Shape size (product of dims).
|
||||
*
|
||||
* \param shape Shape to lookup.
|
||||
* \return Requsted size.
|
||||
*/
|
||||
template <typename... ShapeDims>
|
||||
__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
|
||||
unrolled_shape);
|
||||
}
|
||||
|
||||
// Get dim size (could be returned from get function)
|
||||
/**
|
||||
* \private
|
||||
*/
|
||||
template <typename T>
|
||||
__host__ __device__ T constexpr size(const T& dim)
|
||||
{
|
||||
return dim;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Layout size (product of dims).
|
||||
*
|
||||
* \param layout Layout to calculate shape size.
|
||||
* \return Requsted size.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return layout.GetLengths();
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Length get from tuple (product if tuple).
|
||||
*
|
||||
* \tparam idx Index to lookup.
|
||||
* \param tuple Tuple to lookup.
|
||||
* \return Requsted length.
|
||||
*/
|
||||
template <index_t idx, typename... Ts>
|
||||
__host__ __device__ constexpr index_t size(const Tuple<Ts...>& tuple)
|
||||
{
|
||||
return size(tuple.At(Number<idx>{}));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Hierarchical size.
|
||||
*
|
||||
* \tparam Idxs Indexes to lookup.
|
||||
* \param elem Element to lookup.
|
||||
* \return Requsted element.
|
||||
*/
|
||||
template <index_t... Idxs, typename T>
|
||||
__host__ __device__ constexpr auto size(const T& elem)
|
||||
{
|
||||
return size(get<Idxs...>(elem));
|
||||
}
|
||||
|
||||
// rank
|
||||
/**
|
||||
* \brief Get layout rank (num elements in shape).
|
||||
*
|
||||
* \param layout Layout to calculate rank.
|
||||
* \return Requsted rank.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto rank([[maybe_unused]] const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return Shape::Size();
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get tuple rank (num elements in tuple).
|
||||
* Return 1 if scalar passed.
|
||||
*
|
||||
* \param tuple Tuple to calculate rank.
|
||||
* \return Requsted rank.
|
||||
*/
|
||||
template <typename... Dims>
|
||||
__host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& tuple)
|
||||
{
|
||||
return Tuple<Dims...>::Size();
|
||||
}
|
||||
|
||||
/**
|
||||
* \private
|
||||
*/
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr index_t rank(const Number<IDim>&)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* \private
|
||||
*/
|
||||
__host__ __device__ constexpr index_t rank(const index_t&) { return 1; }
|
||||
|
||||
/**
|
||||
* \brief Hierarchical rank.
|
||||
*
|
||||
* \tparam Idxs Indexes to lookup.
|
||||
* \param elem Element to lookup.
|
||||
* \return Requsted rank.
|
||||
*/
|
||||
template <index_t... Idxs, typename T>
|
||||
__host__ __device__ constexpr auto rank(const T& elem)
|
||||
{
|
||||
return rank(get<Idxs...>(elem));
|
||||
}
|
||||
|
||||
// depth
|
||||
/**
|
||||
* \brief Get depth of the layout shape (return 0 if scalar).
|
||||
*
|
||||
* \param layout Layout to calculate depth.
|
||||
* \return Requsted depth.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto depth(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return TupleDepth(layout.GetShape());
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get depth of the tuple. (return 0 if scalar)
|
||||
*
|
||||
* \param tuple Tuple to calculate depth.
|
||||
* \return Requsted depth.
|
||||
*/
|
||||
template <typename... Dims>
|
||||
__host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
|
||||
{
|
||||
return TupleDepth(tuple);
|
||||
}
|
||||
|
||||
/**
|
||||
* \private
|
||||
*/
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr index_t depth(const Number<IDim>&)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* \private
|
||||
*/
|
||||
__host__ __device__ constexpr index_t depth(const index_t&) { return 0; }
|
||||
|
||||
/**
|
||||
* \brief Hierarchical depth.
|
||||
*
|
||||
* \tparam Idxs Indexes to lookup.
|
||||
* \param elem Element to lookup.
|
||||
* \return Requsted depth.
|
||||
*/
|
||||
template <index_t... Idxs, typename T>
|
||||
__host__ __device__ constexpr auto depth(const T& elem)
|
||||
{
|
||||
return depth(get<Idxs...>(elem));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get Layout strides.
|
||||
*
|
||||
* \param layout Layout to get strides.
|
||||
* \return Requsted strides.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto stride(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return layout.GetStrides();
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get Layout shape.
|
||||
*
|
||||
* \param layout Layout to get shape.
|
||||
* \return Requsted shape.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto shape(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return layout.GetShape();
|
||||
}
|
||||
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
@@ -149,6 +149,7 @@ add_subdirectory(batched_gemm_multi_d)
|
||||
add_subdirectory(grouped_convnd_bwd_data)
|
||||
add_subdirectory(conv_tensor_rearrange)
|
||||
add_subdirectory(transpose)
|
||||
add_subdirectory(wrapper)
|
||||
if(GPU_TARGETS MATCHES "gfx11")
|
||||
add_subdirectory(wmma_op)
|
||||
endif()
|
||||
|
||||
2
test/wrapper/CMakeLists.txt
Normal file
2
test/wrapper/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_gtest_executable(test_layout test_layout.cpp)
|
||||
target_link_libraries(test_layout PRIVATE utility)
|
||||
481
test/wrapper/test_layout.cpp
Normal file
481
test/wrapper/test_layout.cpp
Normal file
@@ -0,0 +1,481 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
|
||||
#include "ck/wrapper/layout.hpp"
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
|
||||
// Fixture that validates ck::wrapper::Layout offset computation against
// reference tensor descriptors built with CK transform primitives.
class TestWrapperLayout : public ::testing::Test
{
    protected:
    // Shorthand compile-time indices.
    static constexpr auto I0 = ck::Number<0>{};
    static constexpr auto I1 = ck::Number<1>{};

    // Compares the runtime-shaped and compile-time-shaped layouts against:
    // - desc:    a reference descriptor with the same dimensionality as idxs,
    // - desc_1d: a flattened (1d) reference descriptor,
    // probing every flat index plus each index tuple in idxs (possibly nested).
    template <typename Desc,
              typename Desc1d,
              typename LayoutRuntime,
              typename LayoutCompiletime,
              typename Idxs>
    void Run(Desc& desc,
             Desc1d& desc_1d,
             LayoutRuntime& layout_runtime,
             LayoutCompiletime& layout_compiletime,
             const std::vector<Idxs>& idxs)
    {
        // 1d check
        EXPECT_EQ(desc_1d.GetLength(I0), ck::wrapper::size(layout_runtime));
        // Check layout compiletime and runtime result consistency
        EXPECT_EQ(ck::wrapper::size(layout_runtime), ck::wrapper::size(layout_compiletime));

        // Every flat (1d) index must yield the same offset as the 1d descriptor.
        for(ck::index_t i = 0; i < desc_1d.GetLength(I0); i++)
        {
            const ck::index_t layout_runtime_offset_1d     = layout_runtime(ck::make_tuple(i));
            const ck::index_t layout_compiletime_offset_1d = layout_compiletime(ck::make_tuple(i));
            const ck::index_t desc_offset_1d = desc_1d.CalculateOffset(ck::make_tuple(i));
            EXPECT_EQ(layout_runtime_offset_1d, desc_offset_1d);
            EXPECT_EQ(layout_compiletime_offset_1d, layout_runtime_offset_1d);
        }
        // size(layout)-d check, don't check if access is hierarchical
        if constexpr(!IsNestedTuple(Idxs{}))
        {
            // Per-dimension lengths must agree with the reference descriptor.
            ck::static_for<0, Idxs::Size(), 1>{}([&](auto d) {
                EXPECT_EQ(desc.GetLength(ck::Number<d>{}), ck::wrapper::size<d>(layout_runtime));
                EXPECT_EQ(ck::wrapper::size<d>(layout_runtime),
                          ck::wrapper::size<d>(layout_compiletime));
            });
        }
        // Probe each provided (possibly nested) index tuple.
        for(const auto idx : idxs)
        {
            const ck::index_t layout_runtime_offset     = layout_runtime(idx);
            const ck::index_t layout_compiletime_offset = layout_compiletime(idx);
            const ck::index_t desc_offset =
                desc.CalculateOffset(UnrollNestedTuple(idx)); // Unroll if nested
            EXPECT_EQ(layout_runtime_offset, desc_offset);
            EXPECT_EQ(layout_runtime_offset, layout_compiletime_offset);
        }
    }
};
|
||||
|
||||
TEST_F(TestWrapperLayout, 2d)
{
    // Column-major layout, dims:(4, 3) strides:(1, 4).
    constexpr ck::index_t dim1    = 4;
    constexpr ck::index_t dim0    = 3;
    constexpr ck::index_t stride1 = 1;
    constexpr ck::index_t stride0 = 4;

    const auto desc = ck::make_naive_tensor_descriptor(
        ck::make_tuple(ck::Number<dim1>{}, ck::Number<dim0>{}),
        ck::make_tuple(ck::Number<stride1>{}, ck::Number<stride0>{}));
    // Merge dims in reversed order since the layout is column major.
    const auto desc_1d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(dim0, dim1))),
        ck::make_tuple(ck::Sequence<1, 0>{}),
        ck::make_tuple(ck::Sequence<0>{}));

    const auto layout_runtime = ck::wrapper::make_layout(ck::make_tuple(dim1, dim0));
    const auto layout_compiletime =
        ck::wrapper::make_layout(ck::make_tuple(ck::Number<dim1>{}, ck::Number<dim0>{}));

    // Enumerate all (row, col) index pairs.
    std::vector<ck::Tuple<ck::index_t, ck::index_t>> indices;
    for(ck::index_t row = 0; row < dim1; ++row)
    {
        for(ck::index_t col = 0; col < dim0; ++col)
        {
            indices.emplace_back(row, col);
        }
    }

    this->Run(desc, desc_1d, layout_runtime, layout_compiletime, indices);
}
|
||||
|
||||
TEST_F(TestWrapperLayout, 3d_nested)
{
    // dims:((2, 3), 4, 3) strides:((2, 4), 12, 48)
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    constexpr ck::index_t s3 = 2;
    constexpr ck::index_t s2 = 4;
    constexpr ck::index_t s1 = 12;
    constexpr ck::index_t s0 = 48;
    const auto desc = ck::make_naive_tensor_descriptor(
        ck::make_tuple(ck::Number<d3>{}, ck::Number<d2>{}, ck::Number<d1>{}, ck::Number<d0>{}),
        ck::make_tuple(ck::Number<s3>{}, ck::Number<s2>{}, ck::Number<s1>{}, ck::Number<s0>{}));
    // Reverse due to column major
    const auto desc_1d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3))),
        ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}),
        ck::make_tuple(ck::Sequence<0>{}));
    // Reference 3d view: merge (d3, d2), pass through d1 and d0.
    // Fix: the last pass-through covers lower dim 3 (length d0); it previously
    // passed d2, which only worked because d2 == d0 == 3.
    const auto desc_3d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3)),
                       ck::make_pass_through_transform(d1),
                       ck::make_pass_through_transform(d0)),
        ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}),
        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
    const auto layout_runtime =
        ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(d3, d2), d1, d0),
                                 ck::make_tuple(ck::make_tuple(s3, s2), s1, s0));
    const auto layout_compiletime = ck::wrapper::make_layout(
        ck::make_tuple(
            ck::make_tuple(ck::Number<d3>{}, ck::Number<d2>{}), ck::Number<d1>{}, ck::Number<d0>{}),
        ck::make_tuple(ck::make_tuple(ck::Number<s3>{}, ck::Number<s2>{}),
                       ck::Number<s1>{},
                       ck::Number<s0>{}));
    // Flat 3d indexing: merged (d3, d2) dim iterated as a single range.
    std::vector<ck::Tuple<ck::index_t, ck::index_t, ck::index_t>> idxs_3d;

    for(ck::index_t d = 0; d < d2 * d3; d++)
    {
        for(ck::index_t h = 0; h < d1; h++)
        {
            for(ck::index_t w = 0; w < d0; w++)
            {
                idxs_3d.emplace_back(d, h, w);
            }
        }
    }
    this->Run(desc_3d, desc_1d, layout_runtime, layout_compiletime, idxs_3d);

    // Check also 4d iteration
    std::vector<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::index_t, ck::index_t>> idxs_4d;

    for(ck::index_t e = 0; e < d3; e++)
    {
        for(ck::index_t d = 0; d < d2; d++)
        {
            for(ck::index_t h = 0; h < d1; h++)
            {
                for(ck::index_t w = 0; w < d0; w++)
                {
                    idxs_4d.emplace_back(ck::make_tuple(e, d), h, w);
                }
            }
        }
    }
    this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_4d);
}
|
||||
|
||||
TEST_F(TestWrapperLayout, 2d_nested)
{
    // dims:((2, 3), (4, 3)) strides:((2, 4), (48, 12))
    constexpr ck::index_t dim3 = 2;
    constexpr ck::index_t dim2 = 3;
    constexpr ck::index_t dim1 = 4;
    constexpr ck::index_t dim0 = 3;
    constexpr ck::index_t str3 = 2;
    constexpr ck::index_t str2 = 4;
    constexpr ck::index_t str1 = 48;
    constexpr ck::index_t str0 = 12;

    const auto desc = ck::make_naive_tensor_descriptor(
        ck::make_tuple(
            ck::Number<dim3>{}, ck::Number<dim2>{}, ck::Number<dim1>{}, ck::Number<dim0>{}),
        ck::make_tuple(
            ck::Number<str3>{}, ck::Number<str2>{}, ck::Number<str1>{}, ck::Number<str0>{}));
    // Merge all dims in reversed order (column major) for the 1d reference.
    const auto desc_1d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(dim0, dim1, dim2, dim3))),
        ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}),
        ck::make_tuple(ck::Sequence<0>{}));
    // 2d reference: ((dim3 x dim2), (dim1 x dim0)).
    const auto desc_2d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(dim2, dim3)),
                       ck::make_merge_transform(ck::make_tuple(dim0, dim1))),
        ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}),
        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));

    const auto layout_runtime = ck::wrapper::make_layout(
        ck::make_tuple(ck::make_tuple(dim3, dim2), ck::make_tuple(dim1, dim0)),
        ck::make_tuple(ck::make_tuple(str3, str2), ck::make_tuple(str1, str0)));
    const auto layout_compiletime = ck::wrapper::make_layout(
        ck::make_tuple(ck::make_tuple(ck::Number<dim3>{}, ck::Number<dim2>{}),
                       ck::make_tuple(ck::Number<dim1>{}, ck::Number<dim0>{})),
        ck::make_tuple(ck::make_tuple(ck::Number<str3>{}, ck::Number<str2>{}),
                       ck::make_tuple(ck::Number<str1>{}, ck::Number<str0>{})));

    // Flat 2d indexing over the merged dims.
    std::vector<ck::Tuple<ck::index_t, ck::index_t>> indices_2d;
    for(ck::index_t row = 0; row < dim2 * dim3; ++row)
    {
        for(ck::index_t col = 0; col < dim0 * dim1; ++col)
        {
            indices_2d.emplace_back(row, col);
        }
    }
    this->Run(desc_2d, desc_1d, layout_runtime, layout_compiletime, indices_2d);

    // Hierarchical 4d indexing via nested tuples.
    std::vector<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::Tuple<ck::index_t, ck::index_t>>>
        indices_4d;
    for(ck::index_t i3 = 0; i3 < dim3; ++i3)
    {
        for(ck::index_t i2 = 0; i2 < dim2; ++i2)
        {
            for(ck::index_t i1 = 0; i1 < dim1; ++i1)
            {
                for(ck::index_t i0 = 0; i0 < dim0; ++i0)
                {
                    indices_4d.emplace_back(ck::make_tuple(i3, i2), ck::make_tuple(i1, i0));
                }
            }
        }
    }
    this->Run(desc, desc_1d, layout_runtime, layout_compiletime, indices_4d);
}
|
||||
|
||||
TEST_F(TestWrapperLayout, 3d_double_nested)
{
    // dims:(((2, 2), 3), (4, 3)) strides:(((2, 4), 8), (96, 24))
    constexpr ck::index_t d4 = 2;
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    constexpr ck::index_t s4 = 2;
    constexpr ck::index_t s3 = 4;
    constexpr ck::index_t s2 = 8;
    constexpr ck::index_t s1 = 96;
    constexpr ck::index_t s0 = 24;
    const auto desc = ck::make_naive_tensor_descriptor(
        ck::make_tuple(
            ck::Number<d4>{}, ck::Number<d3>{}, ck::Number<d2>{}, ck::Number<d1>{}, ck::Number<d0>{}),
        ck::make_tuple(
            ck::Number<s4>{}, ck::Number<s3>{}, ck::Number<s2>{}, ck::Number<s1>{}, ck::Number<s0>{}));
    // Reverse due to column major
    const auto desc_1d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3, d4))),
        ck::make_tuple(ck::Sequence<4, 3, 2, 1, 0>{}),
        ck::make_tuple(ck::Sequence<0>{}));
    // Reference 3d view: ((d4 x d3), d2, (d1 x d0)).
    const auto desc_3d = transform_tensor_descriptor(
        desc,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d3, d4)),
                       ck::make_pass_through_transform(d2),
                       ck::make_merge_transform(ck::make_tuple(d0, d1))),
        ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<4, 3>{}),
        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
    // Reference 2d view built on top of the 3d view.
    const auto desc_2d = transform_tensor_descriptor(
        desc_3d,
        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3 * d4)),
                       ck::make_pass_through_transform(d1 * d0)),
        ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}),
        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
    // Fix: innermost stride must be s4 in both layouts (it was mistakenly d4,
    // which was masked because d4 == s4 == 2).
    const auto layout_runtime = ck::wrapper::make_layout(
        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)),
        ck::make_tuple(ck::make_tuple(ck::make_tuple(s4, s3), s2), ck::make_tuple(s1, s0)));
    const auto layout_compiletime = ck::wrapper::make_layout(
        ck::make_tuple(
            ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
            ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})),
        ck::make_tuple(
            ck::make_tuple(ck::make_tuple(ck::Number<s4>{}, ck::Number<s3>{}), ck::Number<s2>{}),
            ck::make_tuple(ck::Number<s1>{}, ck::Number<s0>{})));
    // Flat 2d indexing.
    std::vector<ck::Tuple<ck::index_t, ck::index_t>> idxs_2d;

    for(ck::index_t h = 0; h < d2 * d3 * d4; h++)
    {
        for(ck::index_t w = 0; w < d0 * d1; w++)
        {
            idxs_2d.emplace_back(h, w);
        }
    }
    this->Run(desc_2d, desc_1d, layout_runtime, layout_compiletime, idxs_2d);
    // Check also 3d iteration
    std::vector<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::index_t>> idxs_3d;

    for(ck::index_t d = 0; d < d3 * d4; d++)
    {
        for(ck::index_t h = 0; h < d2; h++)
        {
            for(ck::index_t w = 0; w < d1 * d0; w++)
            {
                idxs_3d.emplace_back(ck::make_tuple(d, h), w);
            }
        }
    }
    this->Run(desc_3d, desc_1d, layout_runtime, layout_compiletime, idxs_3d);
    // Check also 5d iteration
    std::vector<ck::Tuple<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::index_t>,
                          ck::Tuple<ck::index_t, ck::index_t>>>
        idxs_5d;

    for(ck::index_t f = 0; f < d4; f++)
    {
        for(ck::index_t e = 0; e < d3; e++)
        {
            for(ck::index_t d = 0; d < d2; d++)
            {
                for(ck::index_t h = 0; h < d1; h++)
                {
                    for(ck::index_t w = 0; w < d0; w++)
                    {
                        idxs_5d.emplace_back(ck::make_tuple(ck::make_tuple(f, e), d),
                                             ck::make_tuple(h, w));
                    }
                }
            }
        }
    }
    this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_5d);
}
|
||||
|
||||
TEST(TestLayoutHelpers, SizeAndGet)
{
    // dims:(((2, 2), 3), (4, 3))
    constexpr ck::index_t d4 = 2;
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    const auto layout_runtime = ck::wrapper::make_layout(
        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)));
    const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple(
        ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
        ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})));

    // Size of layout
    EXPECT_EQ(ck::wrapper::size(layout_runtime), d4 * d3 * d2 * d1 * d0);
    EXPECT_EQ(ck::wrapper::size(layout_compiletime), d4 * d3 * d2 * d1 * d0);

    // Size of dims
    EXPECT_EQ(ck::wrapper::size<0>(layout_runtime), d4 * d3 * d2);
    EXPECT_EQ(ck::wrapper::size<0>(layout_compiletime), d4 * d3 * d2);
    EXPECT_EQ(ck::wrapper::size<1>(layout_runtime), d1 * d0);
    EXPECT_EQ(ck::wrapper::size<1>(layout_compiletime), d1 * d0);

    // Access through new layout (using get with layout object).
    // (The redundant duplicate of the size<1>(get<0>(...)) == d2 checks that
    // previously followed the nested-get block has been removed.)
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_runtime)), d4 * d3);
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_compiletime)), d4 * d3);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_runtime)), d2);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_compiletime)), d2);

    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_runtime))), d4);
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_compiletime))),
              d4);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_runtime))), d3);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_compiletime))),
              d3);

    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<1>(layout_runtime)), d1);
    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<1>(layout_compiletime)), d1);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<1>(layout_runtime)), d0);
    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<1>(layout_compiletime)), d0);
}
|
||||
|
||||
TEST(TestLayoutHelpers, DepthAndRank)
{
    namespace wrp = ck::wrapper;
    // dims:(((2, 2), 3), (4, 3))
    constexpr ck::index_t dim4 = 2;
    constexpr ck::index_t dim3 = 2;
    constexpr ck::index_t dim2 = 3;
    constexpr ck::index_t dim1 = 4;
    constexpr ck::index_t dim0 = 3;
    const auto runtime_layout = wrp::make_layout(ck::make_tuple(
        ck::make_tuple(ck::make_tuple(dim4, dim3), dim2), ck::make_tuple(dim1, dim0)));
    const auto compiletime_layout = wrp::make_layout(ck::make_tuple(
        ck::make_tuple(ck::make_tuple(ck::Number<dim4>{}, ck::Number<dim3>{}), ck::Number<dim2>{}),
        ck::make_tuple(ck::Number<dim1>{}, ck::Number<dim0>{})));

    // depth checks (a plain integer has depth 0).
    EXPECT_EQ(wrp::depth(runtime_layout), 3);
    EXPECT_EQ(wrp::depth(compiletime_layout), 3);
    EXPECT_EQ(wrp::depth(ck::make_tuple(ck::make_tuple(dim4, dim3), dim2)), 2);
    EXPECT_EQ(wrp::depth(dim0), 0);

    // rank checks (a plain integer has rank 1).
    EXPECT_EQ(wrp::rank(runtime_layout), 2);
    EXPECT_EQ(wrp::rank(compiletime_layout), 2);
    EXPECT_EQ(wrp::rank(ck::make_tuple(ck::make_tuple(dim4, dim3), dim2)), 2);
    EXPECT_EQ(wrp::rank(dim0), 1);
}
|
||||
|
||||
TEST(TestLayoutHelpers, ShapeAndStrides)
{
    // dims:(((2, 2), 3), (4, 3)) strides:(((2, 4), 8), (96, 24))
    constexpr ck::index_t d4 = 2;
    constexpr ck::index_t d3 = 2;
    constexpr ck::index_t d2 = 3;
    constexpr ck::index_t d1 = 4;
    constexpr ck::index_t d0 = 3;
    constexpr ck::index_t s4 = 2;
    constexpr ck::index_t s3 = 4;
    constexpr ck::index_t s2 = 8;
    constexpr ck::index_t s1 = 96;
    constexpr ck::index_t s0 = 24;
    const auto shape_compiletime = ck::make_tuple(
        ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
        ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{}));
    const auto strides_compiletime = ck::make_tuple(
        ck::make_tuple(ck::make_tuple(ck::Number<s4>{}, ck::Number<s3>{}), ck::Number<s2>{}),
        ck::make_tuple(ck::Number<s1>{}, ck::Number<s0>{}));
    const auto shape_runtime =
        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0));
    const auto strides_runtime =
        ck::make_tuple(ck::make_tuple(ck::make_tuple(s4, s3), s2), ck::make_tuple(s1, s0));
    const auto layout_runtime = ck::wrapper::make_layout(shape_runtime, strides_runtime);
    const auto layout_compiletime =
        ck::wrapper::make_layout(shape_compiletime, strides_compiletime);

    // shape()/stride() must return exactly the tuple types the layout was
    // constructed from (modulo the top-level const of the local variables).
    // Uses std::remove_const_t to match the C++17 _v/_t style already used here.
    constexpr bool check_compiletime_shape =
        std::is_same_v<std::remove_const_t<decltype(shape_compiletime)>,
                       decltype(shape(layout_compiletime))>;
    constexpr bool check_compiletime_strides =
        std::is_same_v<std::remove_const_t<decltype(strides_compiletime)>,
                       decltype(stride(layout_compiletime))>;
    constexpr bool check_runtime_shape =
        std::is_same_v<std::remove_const_t<decltype(shape_runtime)>,
                       decltype(shape(layout_runtime))>;
    constexpr bool check_runtime_strides =
        std::is_same_v<std::remove_const_t<decltype(strides_runtime)>,
                       decltype(stride(layout_runtime))>;
    EXPECT_TRUE(check_compiletime_shape);
    EXPECT_TRUE(check_compiletime_strides);
    EXPECT_TRUE(check_runtime_shape);
    EXPECT_TRUE(check_runtime_strides);
}
|
||||
|
||||
TEST(TestLayoutHelpers, Hierarchical)
{
    namespace wrp = ck::wrapper;
    // dims:(((2, 2), 3), (4, 3))
    constexpr ck::index_t dim4 = 2;
    constexpr ck::index_t dim3 = 2;
    constexpr ck::index_t dim2 = 3;
    constexpr ck::index_t dim1 = 4;
    constexpr ck::index_t dim0 = 3;
    const auto nested_shape = ck::make_tuple(
        ck::make_tuple(ck::make_tuple(dim4, dim3), dim2), ck::make_tuple(dim1, dim0));
    const auto runtime_layout = wrp::make_layout(nested_shape);
    const auto compiletime_layout = wrp::make_layout(ck::make_tuple(
        ck::make_tuple(ck::make_tuple(ck::Number<dim4>{}, ck::Number<dim3>{}), ck::Number<dim2>{}),
        ck::make_tuple(ck::Number<dim1>{}, ck::Number<dim0>{})));

    // Hierarchical lookups: multi-index template args drill into nested dims.
    EXPECT_EQ((wrp::rank<0, 0>(nested_shape)), 2);
    EXPECT_EQ((wrp::rank<0, 0>(runtime_layout)), 2);
    EXPECT_EQ((wrp::rank<0, 0>(compiletime_layout)), 2);

    EXPECT_EQ((wrp::depth<0, 0>(nested_shape)), 1);
    EXPECT_EQ((wrp::depth<0, 0>(runtime_layout)), 1);
    EXPECT_EQ((wrp::depth<0, 0>(compiletime_layout)), 1);

    EXPECT_EQ((wrp::size<0, 0>(nested_shape)), dim4 * dim3);
    EXPECT_EQ((wrp::size<0, 0>(runtime_layout)), dim4 * dim3);
    EXPECT_EQ((wrp::size<0, 0>(compiletime_layout)), dim4 * dim3);

    EXPECT_EQ((wrp::get<0, 0, 0>(nested_shape)), dim4);
}
|
||||
Reference in New Issue
Block a user