mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Introduce wrapper library (#1071)
* Introduce wrapper library
* Update cmake files
* Revert "Update cmake files"
This reverts commit c27f88b565.
* Fix comments
This commit is contained in:
4
client_example/25_tensor_transforms/CMakeLists.txt
Normal file
4
client_example/25_tensor_transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
add_executable(client_tensor_transform tensor_transform.cpp)
|
||||
target_link_libraries(client_tensor_transform PRIVATE composable_kernel::device_other_operations)
|
||||
add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
|
||||
target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
|
||||
150
client_example/25_tensor_transforms/tensor_transform.cpp
Normal file
150
client_example/25_tensor_transforms/tensor_transform.cpp
Normal file
@@ -0,0 +1,150 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
|
||||
static constexpr auto I0 = ck::Number<0>{};
|
||||
static constexpr auto I1 = ck::Number<1>{};
|
||||
static constexpr auto I2 = ck::Number<2>{};
|
||||
|
||||
using DataType = int;
|
||||
|
||||
template <typename Desc>
|
||||
void Print1d(const Desc& desc)
|
||||
{
|
||||
std::cout << "Print1d" << std::endl;
|
||||
for(ck::index_t w = 0; w < desc.GetLength(I0); w++)
|
||||
{
|
||||
std::cout << desc.CalculateOffset(ck::make_tuple(w)) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename Desc>
|
||||
void Print2d(const Desc& desc)
|
||||
{
|
||||
std::cout << "Print2d" << std::endl;
|
||||
for(ck::index_t h = 0; h < desc.GetLength(I0); h++)
|
||||
{
|
||||
for(ck::index_t w = 0; w < desc.GetLength(I1); w++)
|
||||
{
|
||||
std::cout << desc.CalculateOffset(ck::make_tuple(h, w)) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Desc>
|
||||
void Print3dCustom(const Desc& desc)
|
||||
{
|
||||
std::cout << "Print3dCustom" << std::endl;
|
||||
for(ck::index_t d = 0; d < desc.GetLength(I0); d++)
|
||||
{
|
||||
for(ck::index_t h = 0; h < desc.GetLength(I1); h++)
|
||||
{
|
||||
for(ck::index_t w = 0; w < desc.GetLength(I2); w++)
|
||||
{
|
||||
std::cout << desc.CalculateOffset(ck::make_tuple(d, h, w)) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
// Tensor descriptor traverse in row-major (need to reverse dims)
|
||||
std::cout << "Note: Tensor descriptor traverse in row-major" << std::endl;
|
||||
// Basic descriptor 0, 1, 2, ... 30, 31
|
||||
// (dims:4,8 strides:1,4)
|
||||
const auto desc_4x8_s1x4 =
|
||||
ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}),
|
||||
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}));
|
||||
std::cout << "dims:4,8 strides:1,4" << std::endl;
|
||||
Print2d(desc_4x8_s1x4);
|
||||
|
||||
using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
|
||||
constexpr ck::index_t offset_1x1 = desc_4x8_s1x4.CalculateOffset(Cord1x1Type{});
|
||||
std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl;
|
||||
|
||||
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
|
||||
// dims:4,(2,4) strides:2,(1,8)
|
||||
const auto desc_4x2x4_s2x1x8 =
|
||||
ck::make_naive_tensor_descriptor(ck::make_tuple(4, 2, 4), ck::make_tuple(2, 1, 8));
|
||||
// Transform to 2d (column-major, need to to reverse dims)
|
||||
const auto desc_4x2x4_s2x1x8_merged = ck::transform_tensor_descriptor(
|
||||
desc_4x2x4_s2x1x8,
|
||||
ck::make_tuple(ck::make_pass_through_transform(4),
|
||||
ck::make_merge_transform(ck::make_tuple(4, 2))),
|
||||
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<2, 1>{}),
|
||||
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
|
||||
|
||||
std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
|
||||
Print2d(desc_4x2x4_s2x1x8_merged);
|
||||
|
||||
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
|
||||
// dims:(2,2),(2,4) strides:((1,4),(2,8)
|
||||
const auto desc_2x2x2x4_s1x4x2x8 =
|
||||
ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8));
|
||||
// Transform to 2d
|
||||
const auto desc_2x2x2x4_s1x4x2x8_double_merged_2d = ck::transform_tensor_descriptor(
|
||||
desc_2x2x2x4_s1x4x2x8,
|
||||
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
|
||||
ck::make_merge_transform(ck::make_tuple(4, 2))),
|
||||
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}),
|
||||
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
|
||||
// Transform to 3d
|
||||
const auto desc_2x2x2x4_s1x4x2x8_double_merged_3d = ck::transform_tensor_descriptor(
|
||||
desc_2x2x2x4_s1x4x2x8,
|
||||
ck::make_tuple(ck::make_pass_through_transform(2),
|
||||
ck::make_pass_through_transform(2),
|
||||
ck::make_merge_transform(ck::make_tuple(4, 2))),
|
||||
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<3, 2>{}),
|
||||
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
|
||||
|
||||
std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
|
||||
Print2d(desc_2x2x2x4_s1x4x2x8_double_merged_2d);
|
||||
Print3dCustom(desc_2x2x2x4_s1x4x2x8_double_merged_3d);
|
||||
|
||||
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
|
||||
// dims:((2,2),2),4 strides:((1,4),2),8
|
||||
// Transform to 2d
|
||||
const auto desc_2x2x2x4_s1x4x2x8_nested =
|
||||
ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8));
|
||||
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_3d = ck::transform_tensor_descriptor(
|
||||
desc_2x2x2x4_s1x4x2x8_nested,
|
||||
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
|
||||
ck::make_pass_through_transform(2),
|
||||
ck::make_pass_through_transform(4)),
|
||||
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}),
|
||||
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
|
||||
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_1d = ck::transform_tensor_descriptor(
|
||||
desc_2x2x2x4_s1x4x2x8_nested,
|
||||
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(4, 2, 2, 2))),
|
||||
ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}),
|
||||
ck::make_tuple(ck::Sequence<0>{}));
|
||||
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_2d = ck::transform_tensor_descriptor(
|
||||
desc_2x2x2x4_s1x4x2x8_nested_merged_3d,
|
||||
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 4)),
|
||||
ck::make_pass_through_transform(4)),
|
||||
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}),
|
||||
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
|
||||
|
||||
std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl;
|
||||
Print1d(desc_2x2x2x4_s1x4x2x8_nested_merged_1d);
|
||||
Print2d(desc_2x2x2x4_s1x4x2x8_nested_merged_2d);
|
||||
Print3dCustom(desc_2x2x2x4_s1x4x2x8_nested_merged_3d);
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
#include "ck/wrapper/layout.hpp"
|
||||
|
||||
using DataType = int;
|
||||
|
||||
template <typename Layout>
|
||||
void Print1d(const Layout& layout)
|
||||
{
|
||||
std::cout << "Print1d" << std::endl;
|
||||
for(ck::index_t w = 0; w < ck::wrapper::size(layout); w++)
|
||||
{
|
||||
std::cout << layout(ck::make_tuple(w)) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename Layout>
|
||||
void Print2d(const Layout& layout)
|
||||
{
|
||||
std::cout << "Print2d" << std::endl;
|
||||
for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++)
|
||||
{
|
||||
for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++)
|
||||
{
|
||||
std::cout << layout(ck::make_tuple(h, w)) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Print in (x,y),z pattern
|
||||
template <typename Layout>
|
||||
void Print3dCustom(const Layout& layout)
|
||||
{
|
||||
std::cout << "Print3dCustom" << std::endl;
|
||||
for(ck::index_t d = 0; d < ck::wrapper::size<0>(ck::wrapper::get<0>(layout)); d++)
|
||||
{
|
||||
for(ck::index_t h = 0; h < ck::wrapper::size<1>(ck::wrapper::get<0>(layout)); h++)
|
||||
{
|
||||
for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++)
|
||||
{
|
||||
std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
// Layout traverse in row-major
|
||||
std::cout << "Note: Layout traverse in column-major" << std::endl;
|
||||
// Basic descriptor 0, 1, 2, ... 30, 31 (compile-time descriptor)
|
||||
// (dims:4,8 strides:1,4)
|
||||
const auto shape_4x8 = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{});
|
||||
const auto layout_4x8_s1x4 = ck::wrapper::make_layout(shape_4x8);
|
||||
std::cout << "dims:4,8 strides:1,4" << std::endl;
|
||||
Print2d(layout_4x8_s1x4);
|
||||
using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
|
||||
constexpr ck::index_t offset_1x1 = layout_4x8_s1x4.template operator()<Cord1x1Type>();
|
||||
std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl;
|
||||
|
||||
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor)
|
||||
// dims:4,(2,4) strides:2,(1,8)
|
||||
const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4));
|
||||
const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8));
|
||||
const auto layout_4x2x4_s2x1x8 = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
|
||||
|
||||
std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
|
||||
Print2d(layout_4x2x4_s2x1x8);
|
||||
|
||||
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
|
||||
// dims:(2,2),(2,4) strides:((1,4),(2,8)
|
||||
const auto shape_2x2x2x4 = ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}),
|
||||
ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}));
|
||||
const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}),
|
||||
ck::make_tuple(ck::Number<2>{}, ck::Number<8>{}));
|
||||
static const auto layout_2x2x2x4_s1x4x2x8 =
|
||||
ck::wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8);
|
||||
|
||||
std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
|
||||
Print2d(layout_2x2x2x4_s1x4x2x8);
|
||||
Print3dCustom(layout_2x2x2x4_s1x4x2x8);
|
||||
|
||||
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
|
||||
// dims:((2,2),2),4 strides:((1,4),2),8
|
||||
// Transform to 2d
|
||||
const auto shape_2x2x2x4_nested = ck::make_tuple(
|
||||
ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<2>{}),
|
||||
ck::Number<4>{});
|
||||
const auto strides_s1x4x2x8_nested = ck::make_tuple(
|
||||
ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}),
|
||||
ck::Number<8>{});
|
||||
static const auto layout_2x2x2x4_s1x4x2x8_nested =
|
||||
ck::wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested);
|
||||
|
||||
std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl;
|
||||
Print1d(layout_2x2x2x4_s1x4x2x8_nested);
|
||||
Print2d(layout_2x2x2x4_s1x4x2x8_nested);
|
||||
Print3dCustom(layout_2x2x2x4_s1x4x2x8_nested);
|
||||
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user