mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add tensor structure to wrapper (#1098)
* Add tensor structure to wrapper
* update changelog
* Fix names
* Comment fixes
[ROCm/composable_kernel commit: 07092d68f0]
This commit is contained in:
@@ -19,7 +19,7 @@ None
|
||||
- Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
|
||||
- Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799)
|
||||
- Support for Batched Gemm DL (#732)
|
||||
- Introduce wrapper sublibrary (limited functionality) (#1071)
|
||||
- Introduce wrapper sublibrary (limited functionality). (#1071, #1098)
|
||||
|
||||
### Changes
|
||||
- Changed the grouped convolution API to maintain consistency with other convolution kernels (#817)
|
||||
|
||||
@@ -13,7 +13,7 @@ Description
|
||||
|
||||
CK provides a lightweight wrapper for more complex operations implemented in
|
||||
the library. It allows indexing of nested layouts using a simple interface
|
||||
(avoiding complex descriptor transformations).
|
||||
(avoiding complex descriptor transformations) and memory access (using Tensor).
|
||||
|
||||
Example:
|
||||
|
||||
@@ -22,24 +22,31 @@ Example:
|
||||
const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4));
|
||||
const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8));
|
||||
const auto layout = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
|
||||
|
||||
std::array<ck::index_t, 32> data;
|
||||
auto tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(&data[0], layout);
|
||||
|
||||
std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
|
||||
for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++)
|
||||
for(ck::index_t w = 0; w < size(tensor); w++) {
|
||||
tensor(w) = w;
|
||||
}
|
||||
|
||||
// slice() == slice(0, -1) (whole dimension)
|
||||
auto tensor_slice = tensor(ck::wrapper::slice(1, 3), ck::make_tuple(ck::wrapper::slice(), ck::wrapper::slice()));
|
||||
std::cout << "dims:2,(2,4) strides:2,(1,8)" << std::endl;
|
||||
for(ck::index_t h = 0; h < ck::wrapper::size<0>(tensor_slice); h++)
|
||||
{
|
||||
for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++)
|
||||
for(ck::index_t w = 0; w < ck::wrapper::size<1>(tensor_slice); w++)
|
||||
{
|
||||
std::cout << layout(ck::make_tuple(h, w)) << " ";
|
||||
std::cout << tensor_slice(h, w) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
Output::
|
||||
|
||||
dims:4,(2,4) strides:2,(1,8)
|
||||
0 1 8 9 16 17 24 25
|
||||
2 3 10 11 18 19 26 27
|
||||
4 5 12 13 20 21 28 29
|
||||
6 7 14 15 22 23 30 31
|
||||
dims:2,(2,4) strides:2,(1,8)
|
||||
1 5 9 13 17 21 25 29
|
||||
2 6 10 14 18 22 26 30
|
||||
|
||||
-------------------------------------
|
||||
Layout
|
||||
@@ -52,3 +59,15 @@ Layout helpers
|
||||
-------------------------------------
|
||||
|
||||
.. doxygenfile:: layout_utils.hpp
|
||||
|
||||
-------------------------------------
|
||||
Tensor
|
||||
-------------------------------------
|
||||
|
||||
.. doxygenstruct:: ck::wrapper::Tensor
|
||||
|
||||
-------------------------------------
|
||||
Tensor helpers
|
||||
-------------------------------------
|
||||
|
||||
.. doxygenfile:: tensor_utils.hpp
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/wrapper/layout_utils.hpp"
|
||||
#include "ck/wrapper/utils/layout_utils.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
@@ -25,6 +25,26 @@ struct Layout
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
// Generate default idxs tuple (idx with all merged nested shapes)
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr static auto GenerateDefaultIdxsTuple(const Tuple<Ts...>&)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto) {
|
||||
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
|
||||
{
|
||||
// runtime layout
|
||||
return index_t(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
// compiletime layout
|
||||
return I0;
|
||||
}
|
||||
},
|
||||
Number<Tuple<Ts...>::Size()>{});
|
||||
}
|
||||
|
||||
// Generate packed (column-major) strides if not passed
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr static auto
|
||||
@@ -131,7 +151,7 @@ struct Layout
|
||||
|
||||
template <typename... ShapeDims, typename DescriptorToMerge>
|
||||
__host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
|
||||
DescriptorToMerge& desc)
|
||||
const DescriptorToMerge& desc)
|
||||
{
|
||||
// Reverse each element in tuple
|
||||
const auto merge_elems = TupleReverse(UnrollNestedTuple(shape));
|
||||
@@ -144,7 +164,7 @@ struct Layout
|
||||
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
|
||||
}
|
||||
|
||||
// Merge nested shape dims. Merge nested shape dims when idx is also nested.
|
||||
// Merge nested shape dims when corresponding index is also nested.
|
||||
// Input desc shape: 2, 2, 2, 2, 2, 2
|
||||
// Example idx: 1, 1, 1, 1
|
||||
// Example shape: 2, (2, 2), 2, (2, 2)
|
||||
@@ -187,14 +207,38 @@ struct Layout
|
||||
return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
|
||||
}
|
||||
|
||||
template <typename LayoutShape, typename LayoutStrides>
|
||||
__host__ __device__ static auto MakeFlattenDescriptor(const LayoutShape& shape,
|
||||
const LayoutStrides& strides)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
const auto unrolled_strides = UnrollNestedTuple(strides);
|
||||
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
|
||||
"Size of strides and shape are not consistent.");
|
||||
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
|
||||
}
|
||||
|
||||
// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
|
||||
using DeducedStrides =
|
||||
std::conditional_t<is_same_v<Strides, Tuple<>>,
|
||||
remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
|
||||
Strides>;
|
||||
using FlattenDescriptorType =
|
||||
remove_cvref_t<decltype(MakeFlattenDescriptor(Shape{}, DeducedStrides{}))>;
|
||||
using Descriptor1dType =
|
||||
remove_cvref_t<decltype(MakeMerge1d(Shape{}, FlattenDescriptorType{}))>;
|
||||
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
|
||||
|
||||
template <typename... ShapeDims, typename... IdxDims>
|
||||
__host__ __device__ constexpr auto TransformDesc(const Tuple<ShapeDims...>& shape,
|
||||
const Tuple<IdxDims...>& idx) const
|
||||
__host__ __device__ constexpr static auto
|
||||
TransformDesc(const Tuple<ShapeDims...>& shape,
|
||||
const Tuple<IdxDims...>& idx,
|
||||
const FlattenDescriptorType& naive_descriptor)
|
||||
{
|
||||
if constexpr(Tuple<IdxDims...>::Size() == I1)
|
||||
{
|
||||
// 1d idx path
|
||||
return MakeMerge1d(shape, descriptor_);
|
||||
return MakeMerge1d(shape, naive_descriptor);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -207,56 +251,53 @@ struct Layout
|
||||
// Unroll while IdxDims is nested
|
||||
const auto aligned_shape = AlignShapeToIdx(shape, idx);
|
||||
// Transform correct form of shape
|
||||
return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), descriptor_);
|
||||
return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), naive_descriptor);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LayoutShape, typename LayoutStrides>
|
||||
__host__ __device__ static auto MakeNaiveDescriptor(const LayoutShape& shape,
|
||||
const LayoutStrides& strides)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
const auto unrolled_strides = UnrollNestedTuple(strides);
|
||||
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
|
||||
"Size of strides and shape are not consistent.");
|
||||
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
|
||||
}
|
||||
using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
|
||||
Shape{}, DefaultIdxsTupleType{}, FlattenDescriptorType{}))>;
|
||||
|
||||
public:
|
||||
// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
|
||||
using DeducedStrides =
|
||||
std::conditional_t<is_same_v<Strides, Tuple<>>,
|
||||
remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
|
||||
Strides>;
|
||||
using NaiveDescriptorType =
|
||||
remove_cvref_t<decltype(MakeNaiveDescriptor(Shape{}, DeducedStrides{}))>;
|
||||
__host__ __device__ constexpr auto GetElementSpaceSize() const
|
||||
{
|
||||
return flatten_descriptor_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
__host__ __device__ Layout() = delete;
|
||||
/**
|
||||
* \brief Layout constructor.
|
||||
*
|
||||
* \param shape Shape for layout.
|
||||
* \param strides Strides for layout (optional if tensor is packed).
|
||||
* \return Layout object.
|
||||
*/
|
||||
__host__ __device__ Layout() = delete;
|
||||
__host__ __device__ Layout(const Shape& shape, const Strides& strides) : descriptor_{}
|
||||
__host__ __device__ constexpr Layout(const Shape& shape, const Strides& strides)
|
||||
: flatten_descriptor_{}, shape_(shape), strides_(strides)
|
||||
{
|
||||
// Construct if runtime mode
|
||||
if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
|
||||
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
|
||||
{
|
||||
shape_ = shape;
|
||||
strides_ = strides;
|
||||
descriptor_ = MakeNaiveDescriptor(shape_, strides_);
|
||||
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
|
||||
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
|
||||
merged_nests_descriptor_ =
|
||||
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ Layout(const Shape& shape) : descriptor_{}
|
||||
/**
|
||||
* \brief Layout constructor (with default packed column-major strides).
|
||||
*
|
||||
* \param shape Shape for layout.
|
||||
*/
|
||||
__host__ __device__ constexpr Layout(const Shape& shape)
|
||||
: flatten_descriptor_{}, shape_(shape), strides_(GenerateColumnMajorPackedStrides(shape_))
|
||||
{
|
||||
if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
|
||||
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
|
||||
{
|
||||
shape_ = shape;
|
||||
strides_ = GenerateColumnMajorPackedStrides(shape_);
|
||||
descriptor_ = MakeNaiveDescriptor(shape_, strides_);
|
||||
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
|
||||
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
|
||||
merged_nests_descriptor_ =
|
||||
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -269,7 +310,9 @@ struct Layout
|
||||
template <typename Idxs>
|
||||
__host__ __device__ constexpr index_t operator()() const
|
||||
{
|
||||
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}));
|
||||
static_assert(FlattenDescriptorType::IsKnownAtCompileTime(),
|
||||
"Compiletime operator used on runtime layout.");
|
||||
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, FlattenDescriptorType{}));
|
||||
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
|
||||
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
|
||||
}
|
||||
@@ -283,9 +326,22 @@ struct Layout
|
||||
template <typename... Ts>
|
||||
__host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const
|
||||
{
|
||||
// Static to construct transformed_desc only once
|
||||
static const auto transformed_desc = TransformDesc(shape_, Idx);
|
||||
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
|
||||
if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == 1)
|
||||
{
|
||||
// if 1d access
|
||||
return descriptor_1d_.CalculateOffset(Idx);
|
||||
}
|
||||
else if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == Shape::Size())
|
||||
{
|
||||
// if Shape::Size() access (merged nested shapes)
|
||||
return merged_nests_descriptor_.CalculateOffset(UnrollNestedTuple(Idx));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Custom index, need to transform descriptor
|
||||
const auto transformed_desc = TransformDesc(shape_, Idx, flatten_descriptor_);
|
||||
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -327,19 +383,51 @@ struct Layout
|
||||
*
|
||||
* \return Shape.
|
||||
*/
|
||||
__host__ __device__ constexpr Shape GetShape() const { return shape_; }
|
||||
__host__ __device__ constexpr const Shape& GetShape() const { return shape_; }
|
||||
|
||||
/**
|
||||
* \brief Strides getter.
|
||||
*
|
||||
* \return Strides.
|
||||
*/
|
||||
__host__ __device__ constexpr DeducedStrides GetStrides() const { return strides_; }
|
||||
__host__ __device__ constexpr const DeducedStrides& GetStrides() const { return strides_; }
|
||||
|
||||
/**
|
||||
* \brief Get default lengths (tuple filled with Shape length elements).
|
||||
*
|
||||
* \return Default lengths.
|
||||
*/
|
||||
__host__ __device__ constexpr auto GetDefaultLengthsTuple() const
|
||||
{
|
||||
return generate_tuple([&](auto i) { return GetLength<i>(); }, Number<Shape::Size()>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get default start idx (tuple filled with 0s of the same size as Shape).
|
||||
*
|
||||
* \return Default start idx.
|
||||
*/
|
||||
__host__ __device__ constexpr auto GetDefaultStartIdxs() const
|
||||
{
|
||||
return GenerateDefaultIdxsTuple(shape_);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get default descriptor (with the same size as Shape)
|
||||
*
|
||||
* \return Default descriptor.
|
||||
*/
|
||||
__host__ __device__ constexpr MergedNestsDescriptorType GetDefaultDescriptor()
|
||||
{
|
||||
return merged_nests_descriptor_;
|
||||
}
|
||||
|
||||
private:
|
||||
NaiveDescriptorType descriptor_;
|
||||
Shape shape_;
|
||||
DeducedStrides strides_;
|
||||
FlattenDescriptorType flatten_descriptor_;
|
||||
Descriptor1dType descriptor_1d_;
|
||||
MergedNestsDescriptorType merged_nests_descriptor_;
|
||||
const Shape shape_;
|
||||
const DeducedStrides strides_;
|
||||
};
|
||||
|
||||
} // namespace wrapper
|
||||
|
||||
314
include/ck/wrapper/tensor.hpp
Normal file
314
include/ck/wrapper/tensor.hpp
Normal file
@@ -0,0 +1,314 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "utils/tensor_utils.hpp"
|
||||
#include "utils/layout_utils.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
|
||||
/**
|
||||
* \brief Tensor wrapper that performs static and dynamic buffer logic.
|
||||
*
|
||||
* \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
|
||||
* \tparam ElementType Element data type.
|
||||
* \tparam Shape Tensor shape (layout component).
|
||||
* \tparam Strides Tensor strides (layout component).
|
||||
* \tparam NumVectors Number of vectors (only for VGPR, SGPR).
|
||||
* \tparam ScalarPerVector Scalars per vector (only for VGPR, SGPR).
|
||||
*/
|
||||
template <MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides,
|
||||
index_t NumVectors, // param for Register memory
|
||||
index_t ScalarPerVector // param for Register memory
|
||||
>
|
||||
struct Tensor
|
||||
{
|
||||
private:
|
||||
// Check if Tuple contains Slice object
|
||||
template <typename T>
|
||||
constexpr static bool IsSlicing(T&&)
|
||||
{
|
||||
return is_detected<is_slice, T>::value;
|
||||
}
|
||||
template <typename... Ts>
|
||||
constexpr static bool IsSlicing(Tuple<Ts...>&&)
|
||||
{
|
||||
return (IsSlicing(Ts{}) || ...);
|
||||
}
|
||||
|
||||
// Calculate first index of new tensor after slice
|
||||
// It is needed to calculate offset for new tensor
|
||||
template <typename... Ts>
|
||||
constexpr auto GetStartIdxForSlicedTensor(const Tuple<Ts...>& idx) const
|
||||
{
|
||||
const auto start_idx_for_sliced_tensor = generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto num_i = Number<i>{};
|
||||
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
|
||||
{
|
||||
// if tuple then recurrence
|
||||
return GetStartIdxForSlicedTensor(idx.At(num_i));
|
||||
}
|
||||
else if constexpr(is_detected<is_slice,
|
||||
tuple_element_t<i.value, Tuple<Ts...>>>::value)
|
||||
{
|
||||
// if slice, return the beginning of the interval
|
||||
return idx.At(num_i).from_;
|
||||
}
|
||||
else
|
||||
{
|
||||
// if one dim selected
|
||||
return idx.At(num_i);
|
||||
}
|
||||
},
|
||||
Number<Tuple<Ts...>::Size()>{});
|
||||
|
||||
return start_idx_for_sliced_tensor;
|
||||
}
|
||||
|
||||
// Calculate new tensor shape after slice
|
||||
template <typename... Ts, typename ShapeTmpType>
|
||||
constexpr auto GetShapeFromSlicedTensor(const Tuple<Ts...>& idx,
|
||||
const ShapeTmpType& shape) const
|
||||
{
|
||||
// Pack each value in tuple to remove empty tuples after generation
|
||||
auto new_shape = generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto num_i = Number<i>{};
|
||||
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
|
||||
{
|
||||
if constexpr(!IsSlicing(tuple_element_t<i.value, Tuple<Ts...>>{}))
|
||||
{
|
||||
// if tuple does not have any slice then we can remove dimension
|
||||
return Tuple<>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
// if tuple then recurrence
|
||||
return make_tuple(GetShapeFromSlicedTensor(idx.At(num_i), shape.At(num_i)));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_detected<is_slice,
|
||||
tuple_element_t<i.value, Tuple<Ts...>>>::value)
|
||||
{
|
||||
// calculate new dimension
|
||||
const auto& dim = size(shape.At(num_i));
|
||||
const auto val = idx.At(num_i).range(dim);
|
||||
return make_tuple(val);
|
||||
}
|
||||
else
|
||||
{
|
||||
// remove dimension for just value
|
||||
return Tuple<>{};
|
||||
}
|
||||
},
|
||||
Number<Tuple<Ts...>::Size()>{});
|
||||
// Remove empty tuples (deleted elements) and return
|
||||
return UnrollNestedTuple<0, 1>(new_shape);
|
||||
}
|
||||
|
||||
template <typename... Ts, typename StridesTmpType>
|
||||
constexpr auto GetStridesFromSlicedTensor(const Tuple<Ts...>& idx,
|
||||
const StridesTmpType& strides) const
|
||||
{
|
||||
// Pack each value in tuple to remove empty tuples after generation
|
||||
auto new_strides = generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto num_i = Number<i>{};
|
||||
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
|
||||
{
|
||||
if constexpr(!IsSlicing(tuple_element_t<i.value, Tuple<Ts...>>{}))
|
||||
{
|
||||
// if tuple does not have any slice then we can remove dimension
|
||||
return Tuple<>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
// if tuple then recurrence
|
||||
return make_tuple(
|
||||
GetStridesFromSlicedTensor(idx.At(num_i), strides.At(num_i)));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_detected<is_slice,
|
||||
tuple_element_t<i.value, Tuple<Ts...>>>::value)
|
||||
{
|
||||
// Stride will be the same
|
||||
return make_tuple(strides.At(num_i));
|
||||
}
|
||||
else
|
||||
{
|
||||
// remove dimension for just value
|
||||
return Tuple<>{};
|
||||
}
|
||||
},
|
||||
Number<Tuple<Ts...>::Size()>{});
|
||||
// Remove empty tuples (deleted elements) and return
|
||||
return UnrollNestedTuple<0, 1>(new_strides);
|
||||
}
|
||||
|
||||
public:
|
||||
using ElementSpaceSize = decltype(Layout<Shape, Strides>{
|
||||
Shape{}, Strides{}}.GetElementSpaceSize()); // SpaceSize type for buffer
|
||||
using TensorElementType = ElementType; // DataType
|
||||
|
||||
static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
|
||||
static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr ||
|
||||
BufferAddressSpace == MemoryTypeEnum ::Vgpr);
|
||||
|
||||
__host__ __device__ Tensor() = delete;
|
||||
__host__ __device__ Tensor(ElementType* pointer, const Layout<Shape, Strides>& layout)
|
||||
: layout_(layout),
|
||||
buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize()))
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ Tensor(const Layout<Shape, Strides>& layout) : layout_(layout)
|
||||
{
|
||||
static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr const Layout<Shape, Strides>& GetLayout() const
|
||||
{
|
||||
return layout_;
|
||||
}
|
||||
|
||||
// Getter for new sliced tensor
|
||||
template <typename... Ts, enable_if_t<IsSlicing(Tuple<Ts...>{}), bool> = false>
|
||||
__host__ __device__ auto operator[](const Tuple<Ts...>& idx) const
|
||||
{
|
||||
static_assert(IsDynamicBuffer, "Register slice is not supported");
|
||||
// Calculate offset based on first idx for new tensor
|
||||
const index_t offset = layout_(GetStartIdxForSlicedTensor(idx));
|
||||
|
||||
auto new_shape = GetShapeFromSlicedTensor(idx, layout_.GetShape());
|
||||
if constexpr(is_same_v<Strides, Tuple<>>)
|
||||
{
|
||||
auto new_layout = make_layout(new_shape);
|
||||
return make_tensor<BufferAddressSpace>(buffer_.p_data_ + offset, new_layout);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto new_strides = GetStridesFromSlicedTensor(idx, layout_.GetStrides());
|
||||
auto new_layout = make_layout(new_shape, new_strides);
|
||||
return make_tensor<BufferAddressSpace>(buffer_.p_data_ + offset, new_layout);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Ts, enable_if_t<IsSlicing(Tuple<Ts...>{}), bool> = false>
|
||||
__host__ __device__ auto operator()(const Tuple<Ts...>& idx) const
|
||||
{
|
||||
return this->operator[](idx);
|
||||
}
|
||||
|
||||
template <typename... Idxs, enable_if_t<IsSlicing(Tuple<Idxs...>{}), bool> = false>
|
||||
__host__ __device__ auto operator()(Idxs... idxs) const
|
||||
{
|
||||
return this->operator[](make_tuple(idxs...));
|
||||
}
|
||||
|
||||
// Getter for the const value
|
||||
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
|
||||
__host__ __device__ const ElementType& operator[](const Tuple<Ts...>& idx) const
|
||||
{
|
||||
if constexpr(IsDynamicBuffer)
|
||||
{
|
||||
const index_t offset = layout_(idx);
|
||||
return buffer_[offset];
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(is_same_v<Strides, Tuple<>>)
|
||||
{
|
||||
constexpr index_t offset =
|
||||
Layout<Shape, Strides>{Shape{}}.template operator()<Tuple<Ts...>>();
|
||||
return buffer_[Number<offset>{}];
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t offset =
|
||||
Layout<Shape, Strides>{Shape{}, Strides{}}.template operator()<Tuple<Ts...>>();
|
||||
return buffer_[Number<offset>{}];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
|
||||
__host__ __device__ const ElementType& operator()(const Tuple<Ts...>& idx) const
|
||||
{
|
||||
return this->operator[](idx);
|
||||
}
|
||||
|
||||
template <typename... Idxs, enable_if_t<!IsSlicing(Tuple<Idxs...>{}), bool> = false>
|
||||
__host__ __device__ const ElementType& operator()(Idxs... idxs) const
|
||||
{
|
||||
return this->operator[](make_tuple(idxs...));
|
||||
}
|
||||
|
||||
// Getter for the value reference
|
||||
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
|
||||
__host__ __device__ ElementType& operator[](const Tuple<Ts...>& idx)
|
||||
{
|
||||
if constexpr(IsDynamicBuffer)
|
||||
{
|
||||
const index_t offset = layout_(idx);
|
||||
return buffer_(offset);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(is_same_v<Strides, Tuple<>>)
|
||||
{
|
||||
constexpr index_t offset =
|
||||
Layout<Shape, Strides>{Shape{}}.template operator()<Tuple<Ts...>>();
|
||||
return buffer_(Number<offset>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t offset =
|
||||
Layout<Shape, Strides>{Shape{}, Strides{}}.template operator()<Tuple<Ts...>>();
|
||||
return buffer_(Number<offset>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
|
||||
__host__ __device__ ElementType& operator()(const Tuple<Ts...>& idx)
|
||||
{
|
||||
return this->operator[](idx);
|
||||
}
|
||||
|
||||
template <typename... Idxs, enable_if_t<!IsSlicing(Tuple<Idxs...>{}), bool> = false>
|
||||
__host__ __device__ ElementType& operator()(Idxs... idxs)
|
||||
{
|
||||
return this->operator[](make_tuple(idxs...));
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetDefaultDescriptor()
|
||||
{
|
||||
return layout_.GetDefaultDescriptor();
|
||||
}
|
||||
|
||||
private:
|
||||
using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
|
||||
ElementType,
|
||||
ElementSpaceSize,
|
||||
true /*InvalidElementUseNumericalZeroValue*/>;
|
||||
using StaticBufferType =
|
||||
StaticBufferTupleOfVector<BufferAddressSpace,
|
||||
ElementType,
|
||||
NumVectors,
|
||||
ScalarPerVector,
|
||||
true /*InvalidElementUseNumericalZeroValue*/>;
|
||||
// If register use static buffer, else use dynamic buffer
|
||||
using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;
|
||||
|
||||
const Layout<Shape, Strides> layout_;
|
||||
Buffer buffer_;
|
||||
};
|
||||
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
@@ -22,7 +22,7 @@ namespace wrapper {
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond
|
||||
// forward declaration
|
||||
template <typename Shape, typename Strides = Tuple<>>
|
||||
template <typename Shape, typename Strides>
|
||||
struct Layout;
|
||||
|
||||
template <typename T>
|
||||
@@ -52,13 +52,23 @@ __host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& sh
|
||||
* \return Constructed layout.
|
||||
*/
|
||||
template <typename Shape>
|
||||
__host__ __device__ constexpr Layout<Shape> make_layout(const Shape& shape)
|
||||
__host__ __device__ constexpr Layout<Shape, Tuple<>> make_layout(const Shape& shape)
|
||||
{
|
||||
return Layout<Shape>(shape);
|
||||
return Layout<Shape, Tuple<>>(shape);
|
||||
}
|
||||
|
||||
// Layout helpers
|
||||
// get
|
||||
// Get dim (could be returned from get with empty Idxs)
|
||||
/**
|
||||
* \private
|
||||
*/
|
||||
template <typename T>
|
||||
__host__ __device__ T constexpr get(const T& dim)
|
||||
{
|
||||
return dim;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get element from tuple (Shape/Strides/Idxs).
|
||||
*
|
||||
@@ -82,7 +92,8 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
|
||||
template <index_t idx, typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
const auto new_shape = get<idx>(layout.GetShape());
|
||||
const auto& shape = layout.GetShape();
|
||||
const auto& new_shape = get<idx>(shape);
|
||||
static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
|
||||
"Shape of sub layout must be tuple");
|
||||
if constexpr(is_same_v<Strides, Tuple<>>)
|
||||
@@ -92,7 +103,8 @@ __host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto new_strides = get<idx>(layout.GetStrides());
|
||||
const auto& strides = layout.GetStrides();
|
||||
const auto& new_strides = get<idx>(strides);
|
||||
static_assert(is_detected<is_tuple, decltype(new_strides)>::value,
|
||||
"Strides of sub layout must be tuple");
|
||||
return make_layout(new_shape, new_strides);
|
||||
@@ -113,11 +125,21 @@ __host__ __device__ constexpr auto get(const T& elem)
|
||||
}
|
||||
|
||||
// size
|
||||
// Get dim size (could be returned from get function)
|
||||
/**
|
||||
* \private
|
||||
*/
|
||||
template <typename T>
|
||||
__host__ __device__ T constexpr size(const T& dim)
|
||||
{
|
||||
return dim;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Length get (product if tuple).
|
||||
*
|
||||
* \tparam idx Index to lookup.
|
||||
* \param layout Layout to get Shape.
|
||||
* \param layout Layout to get Shape of.
|
||||
* \return Requsted length.
|
||||
*/
|
||||
template <index_t idx, typename Shape, typename Strides>
|
||||
@@ -140,16 +162,6 @@ __host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
|
||||
unrolled_shape);
|
||||
}
|
||||
|
||||
// Get dim size (could be returned from get function)
|
||||
/**
|
||||
* \private
|
||||
*/
|
||||
template <typename T>
|
||||
__host__ __device__ T constexpr size(const T& dim)
|
||||
{
|
||||
return dim;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Layout size (product of dims).
|
||||
*
|
||||
@@ -178,14 +190,15 @@ __host__ __device__ constexpr index_t size(const Tuple<Ts...>& tuple)
|
||||
/**
|
||||
* \brief Hierarchical size.
|
||||
*
|
||||
* \tparam Idxs Indexes to lookup.
|
||||
* \tparam Idx First index to lookup (to avoid empty Idxs).
|
||||
* \tparam Idxs Next indexes to lookup.
|
||||
* \param elem Element to lookup.
|
||||
* \return Requsted element.
|
||||
*/
|
||||
template <index_t... Idxs, typename T>
|
||||
template <index_t Idx, index_t... Idxs, typename T>
|
||||
__host__ __device__ constexpr auto size(const T& elem)
|
||||
{
|
||||
return size(get<Idxs...>(elem));
|
||||
return size(get<Idx, Idxs...>(elem));
|
||||
}
|
||||
|
||||
// rank
|
||||
@@ -251,7 +264,8 @@ __host__ __device__ constexpr auto rank(const T& elem)
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto depth(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return TupleDepth(layout.GetShape());
|
||||
const auto& shape = layout.GetShape();
|
||||
return TupleDepth(shape);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -296,11 +310,11 @@ __host__ __device__ constexpr auto depth(const T& elem)
|
||||
/**
|
||||
* \brief Get Layout strides.
|
||||
*
|
||||
* \param layout Layout to get strides.
|
||||
* \param layout Layout to get strides from.
|
||||
* \return Requsted strides.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto stride(const Layout<Shape, Strides>& layout)
|
||||
__host__ __device__ constexpr const auto& stride(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return layout.GetStrides();
|
||||
}
|
||||
@@ -308,11 +322,11 @@ __host__ __device__ constexpr auto stride(const Layout<Shape, Strides>& layout)
|
||||
/**
|
||||
* \brief Get Layout shape.
|
||||
*
|
||||
* \param layout Layout to get shape.
|
||||
* \param layout Layout to get shape from.
|
||||
* \return Requsted shape.
|
||||
*/
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto shape(const Layout<Shape, Strides>& layout)
|
||||
__host__ __device__ constexpr const auto& shape(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return layout.GetShape();
|
||||
}
|
||||
290
include/ck/wrapper/utils/tensor_utils.hpp
Normal file
290
include/ck/wrapper/utils/tensor_utils.hpp
Normal file
@@ -0,0 +1,290 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/tuple_helper.hpp"
|
||||
#include "ck/utility/dynamic_buffer.hpp"
|
||||
#include "ck/utility/amd_address_space.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
|
||||
/**
|
||||
* \brief Memory type, allowed members:
|
||||
* - Generic,
|
||||
* - Global,
|
||||
* - LDS,
|
||||
* - SGPR,
|
||||
* - VGPR,
|
||||
*/
|
||||
using MemoryTypeEnum = AddressSpaceEnum;
|
||||
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond
|
||||
// forward declarations
|
||||
template <typename Shape, typename Strides>
|
||||
struct Layout;
|
||||
template <MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides,
|
||||
index_t NumVectors, // params for Register memory
|
||||
index_t ScalarPerVector // param for Register memory
|
||||
>
|
||||
|
||||
struct Tensor;
|
||||
|
||||
template <typename FromType, typename ToType>
|
||||
struct Slice
|
||||
{
|
||||
__host__ __device__ constexpr Slice() : from_(), to_() {}
|
||||
__host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr auto range(const T& dim) const
|
||||
{
|
||||
if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
|
||||
is_same_v<T, index_t>)
|
||||
{
|
||||
assert(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_) && "Invalid range");
|
||||
if(to_ < 0)
|
||||
{
|
||||
return dim - from_ + to_ + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// workaround if one end of the interval is index_t and the second one is Number
|
||||
return static_cast<index_t>(to_) - static_cast<index_t>(from_);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(dim >= to_ && from_ >= Number<0>{} && (to_ < 0 || to_ > from_),
|
||||
"Invalid range");
|
||||
if constexpr(to_ < 0)
|
||||
{
|
||||
return dim - from_ + to_ + Number<1>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return to_ - from_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsSlice() { return true; }
|
||||
|
||||
const FromType from_;
|
||||
const ToType to_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using is_slice = decltype(std::declval<T&>().IsSlice());
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
/// @endcond
|
||||
|
||||
/**
|
||||
* \brief Make tensor function.
|
||||
*
|
||||
* \tparam MemoryType Type of memory.
|
||||
* \param pointer Pointer to the memory.
|
||||
* \param layout Tensor layout.
|
||||
* \return Constructed tensor.
|
||||
*/
|
||||
template <MemoryTypeEnum MemoryType, typename ElementType, typename Shape, typename Strides>
|
||||
constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
return Tensor<MemoryType, ElementType, Shape, Strides, 0 /*NumVectors*/, 0 /*ScalarPerVector*/>(
|
||||
pointer, layout);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Make SGPR or VGPR tensor function.
|
||||
*
|
||||
* \tparam MemoryType Type of memory.
|
||||
* \tparam NumVectors Number of vectors.
|
||||
* \tparam ScalarPerVector Scalars per vector.
|
||||
* \tparam ElementType Memory data type.
|
||||
* \param layout Tensor layout.
|
||||
* \return Constructed tensor.
|
||||
*/
|
||||
template <MemoryTypeEnum MemoryType,
|
||||
index_t NumVectors,
|
||||
index_t ScalarPerVector,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides>
|
||||
constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
|
||||
{
|
||||
static_assert(!IsNestedTuple(Shape{}), "Register tensor with nested layout is not supported");
|
||||
return Tensor<MemoryType, ElementType, Shape, Strides, NumVectors, ScalarPerVector>(layout);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get Tensor Layout.
|
||||
*
|
||||
* \param tensor Tensor to get layout of.
|
||||
* \return Requsted layout.
|
||||
*/
|
||||
template <MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides,
|
||||
index_t NumVectors,
|
||||
index_t ScalarPerVector>
|
||||
__host__ __device__ constexpr const auto&
|
||||
layout(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
|
||||
tensor)
|
||||
{
|
||||
return tensor.GetLayout();
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Product of tensor shape dims.
|
||||
*
|
||||
* \tparam Idxs Indexes to access specific shape dim (optional).
|
||||
* \param tensor Tensor to get Shape of.
|
||||
* \return Requsted size.
|
||||
*/
|
||||
template <index_t... Idxs,
|
||||
MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides,
|
||||
index_t NumVectors,
|
||||
index_t ScalarPerVector>
|
||||
__host__ __device__ constexpr index_t
|
||||
size(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
|
||||
tensor)
|
||||
{
|
||||
return size<Idxs...>(tensor.GetLayout());
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Rank of Shape tuple.
|
||||
*
|
||||
* \tparam Idxs Indexes to access specific shape dim (optional).
|
||||
* \param tensor Tensor to get rank of.
|
||||
* \return Requsted rank.
|
||||
*/
|
||||
template <index_t... Idxs,
|
||||
MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides,
|
||||
index_t NumVectors,
|
||||
index_t ScalarPerVector>
|
||||
__host__ __device__ constexpr index_t
|
||||
rank(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
|
||||
tensor)
|
||||
{
|
||||
return rank<Idxs...>(tensor.GetLayout());
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Depth of Shape tuple.
|
||||
*
|
||||
* \tparam Idxs Indexes to access specific shape dim (optional).
|
||||
* \param tensor Tensor to get depth of.
|
||||
* \return Requsted depth.
|
||||
*/
|
||||
template <index_t... Idxs,
|
||||
MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides,
|
||||
index_t NumVectors,
|
||||
index_t ScalarPerVector>
|
||||
__host__ __device__ constexpr index_t
|
||||
depth(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
|
||||
tensor)
|
||||
{
|
||||
return depth<Idxs...>(tensor.GetLayout());
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get Tensor strides.
|
||||
*
|
||||
* \param tensor Tensor to get strides from.
|
||||
* \return Requsted strides.
|
||||
*/
|
||||
template <MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides,
|
||||
index_t NumVectors,
|
||||
index_t ScalarPerVector>
|
||||
__host__ __device__ constexpr const auto&
|
||||
stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
|
||||
tensor)
|
||||
{
|
||||
return stride(tensor.GetLayout());
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get Tensor shape.
|
||||
*
|
||||
* \param tensor Tensor to get shape from.
|
||||
* \return Requsted shape.
|
||||
*/
|
||||
template <MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename Strides,
|
||||
index_t NumVectors,
|
||||
index_t ScalarPerVector>
|
||||
__host__ __device__ constexpr const auto&
|
||||
shape(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
|
||||
tensor)
|
||||
{
|
||||
return shape(tensor.GetLayout());
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get dim slice.
|
||||
*
|
||||
* \param from Beginning of the interval.
|
||||
* \param to End of the interval. (could be also negative to index from the end)
|
||||
* \return Requested slice. Could be used to create sliced tensor from other tensor.
|
||||
*/
|
||||
template <typename FromType, typename ToType>
|
||||
constexpr auto slice(const FromType from, const ToType to)
|
||||
{
|
||||
return Slice<FromType, ToType>(from, to);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get dim slice. (Assumed that from is equal to 1)
|
||||
*
|
||||
* \param to End of the interval. (could be also negative to index from the end)
|
||||
* \return Requested slice. Could be used to create sliced tensor from other tensor.
|
||||
*/
|
||||
template <typename ToType>
|
||||
constexpr auto slice(const ToType to)
|
||||
{
|
||||
if constexpr(is_same_v<ToType, index_t>)
|
||||
{
|
||||
return Slice<index_t, ToType>(0, to);
|
||||
}
|
||||
else
|
||||
{
|
||||
return Slice<Number<0>, ToType>(Number<0>{}, to);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get whole dim slice (from = 0, to = -1).
|
||||
*
|
||||
* \return Requested slice. Could be used to create sliced tensor from other tensor.
|
||||
*/
|
||||
constexpr auto slice() { return Slice<Number<0>, Number<-1>>(Number<0>{}, Number<-1>{}); }
|
||||
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
@@ -1,2 +1,4 @@
|
||||
add_gtest_executable(test_layout test_layout.cpp)
|
||||
target_link_libraries(test_layout PRIVATE utility)
|
||||
add_gtest_executable(test_tensor test_tensor.cpp)
|
||||
target_link_libraries(test_tensor PRIVATE utility)
|
||||
|
||||
@@ -433,17 +433,17 @@ TEST(TestLayoutHelpers, ShapeAndStrides)
|
||||
ck::wrapper::make_layout(shape_compiletime, strides_compiletime);
|
||||
|
||||
constexpr bool check_compiletime_shape =
|
||||
std::is_same_v<std::remove_const<decltype(shape_compiletime)>::type,
|
||||
decltype(shape(layout_compiletime))>;
|
||||
std::is_same_v<decltype(shape_compiletime),
|
||||
std::remove_reference_t<decltype(shape(layout_compiletime))>>;
|
||||
constexpr bool check_compiletime_strides =
|
||||
std::is_same_v<std::remove_const<decltype(strides_compiletime)>::type,
|
||||
decltype(stride(layout_compiletime))>;
|
||||
std::is_same_v<decltype(strides_compiletime),
|
||||
std::remove_reference_t<decltype(stride(layout_compiletime))>>;
|
||||
constexpr bool check_runtime_shape =
|
||||
std::is_same_v<std::remove_const<decltype(shape_runtime)>::type,
|
||||
decltype(shape(layout_runtime))>;
|
||||
std::is_same_v<decltype(shape_runtime),
|
||||
std::remove_reference_t<decltype(shape(layout_runtime))>>;
|
||||
constexpr bool check_runtime_strides =
|
||||
std::is_same_v<std::remove_const<decltype(strides_runtime)>::type,
|
||||
decltype(stride(layout_runtime))>;
|
||||
std::is_same_v<decltype(strides_runtime),
|
||||
std::remove_reference_t<decltype(stride(layout_runtime))>>;
|
||||
EXPECT_TRUE(check_compiletime_shape);
|
||||
EXPECT_TRUE(check_compiletime_strides);
|
||||
EXPECT_TRUE(check_runtime_shape);
|
||||
|
||||
205
test/wrapper/test_tensor.cpp
Normal file
205
test/wrapper/test_tensor.cpp
Normal file
@@ -0,0 +1,205 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
|
||||
#include "ck/wrapper/layout.hpp"
|
||||
#include "ck/wrapper/tensor.hpp"
|
||||
|
||||
// Compare data in tensor with offset from layout.
|
||||
// Data and offset should match if physical memory has been initialized with
|
||||
// sequentially increasing values from 0.
|
||||
template <typename TensorType>
|
||||
__host__ __device__ bool TestTensorCheck3d(TensorType& tensor)
|
||||
{
|
||||
const auto& layout = ck::wrapper::layout(tensor);
|
||||
for(ck::index_t d = 0; d < ck::wrapper::size<0>(ck::wrapper::get<0>(layout)); d++)
|
||||
{
|
||||
for(ck::index_t h = 0; h < ck::wrapper::size<1>(ck::wrapper::get<0>(layout)); h++)
|
||||
{
|
||||
for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++)
|
||||
{
|
||||
const auto idx = ck::make_tuple(ck::make_tuple(d, h), w);
|
||||
if(tensor(idx) != layout(idx))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename TensorType>
|
||||
__host__ __device__ bool TestTensorCheck1d(TensorType& tensor, ck::index_t start_offset = 0)
|
||||
{
|
||||
const auto& layout = ck::wrapper::layout(tensor);
|
||||
for(ck::index_t w = 0; w < ck::wrapper::size<0>(layout); w++)
|
||||
{
|
||||
if(tensor(w) - start_offset != layout(ck::make_tuple(w)))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <ck::index_t nelems, typename TensorType>
|
||||
__host__ __device__ bool StaticTestTensorCheck1d(TensorType& tensor)
|
||||
{
|
||||
const auto& layout = ck::wrapper::layout(tensor);
|
||||
bool success = true;
|
||||
ck::static_for<0, nelems, 1>{}([&](auto w) {
|
||||
if(tensor(ck::Number<w.value>{}) != layout(ck::make_tuple(w.value)))
|
||||
{
|
||||
success = false;
|
||||
}
|
||||
});
|
||||
return success;
|
||||
}
|
||||
|
||||
template <typename TensorType>
|
||||
__host__ __device__ void InitTensor(TensorType& tensor)
|
||||
{
|
||||
for(ck::index_t i = 0; i < ck::wrapper::size(ck::wrapper::layout(tensor)); i++)
|
||||
{
|
||||
tensor(i) = i;
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t nelems, typename TensorType>
|
||||
__host__ __device__ void StaticInitTensor(TensorType& tensor)
|
||||
{
|
||||
|
||||
ck::static_for<0, nelems, 1>{}([&](auto i) { tensor(ck::Number<i.value>{}) = i.value; });
|
||||
}
|
||||
|
||||
// Tests
|
||||
TEST(TestTensor, ReadWriteHostMemory)
|
||||
{
|
||||
constexpr ck::index_t nelems = 8;
|
||||
|
||||
std::array<ck::index_t, nelems> data;
|
||||
const auto layout = ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(2, 2), 2));
|
||||
auto tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(&data[0], layout);
|
||||
InitTensor(tensor);
|
||||
|
||||
EXPECT_TRUE(TestTensorCheck1d(tensor));
|
||||
EXPECT_TRUE(TestTensorCheck3d(tensor));
|
||||
}
|
||||
|
||||
__global__ void TestTensorReadWriteDevice(void* data, void* success)
|
||||
{
|
||||
constexpr ck::index_t nelems = 8;
|
||||
constexpr ck::index_t scalar_per_vector = 1;
|
||||
__shared__ ck::index_t p_shared[nelems];
|
||||
|
||||
ck::index_t* casted_data_ptr = static_cast<ck::index_t*>(data);
|
||||
bool* casted_success_ptr = static_cast<bool*>(success);
|
||||
|
||||
const auto layout = ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(2, 2), 2));
|
||||
constexpr auto register_layout = ck::wrapper::make_layout(ck::make_tuple(ck::Number<8>{}));
|
||||
|
||||
auto tensor_global =
|
||||
ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(casted_data_ptr, layout);
|
||||
auto tensor_lds = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(p_shared, layout);
|
||||
auto tensor_vgpr = ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr,
|
||||
nelems,
|
||||
scalar_per_vector,
|
||||
ck::index_t>(register_layout);
|
||||
auto tensor_sgpr = ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Sgpr,
|
||||
nelems,
|
||||
scalar_per_vector,
|
||||
ck::index_t>(register_layout);
|
||||
|
||||
InitTensor(tensor_global);
|
||||
InitTensor(tensor_lds);
|
||||
StaticInitTensor<nelems>(tensor_vgpr);
|
||||
StaticInitTensor<nelems>(tensor_sgpr);
|
||||
|
||||
*casted_success_ptr &= TestTensorCheck1d(tensor_global);
|
||||
*casted_success_ptr &= TestTensorCheck3d(tensor_global);
|
||||
|
||||
*casted_success_ptr &= TestTensorCheck1d(tensor_lds);
|
||||
*casted_success_ptr &= TestTensorCheck3d(tensor_lds);
|
||||
|
||||
*casted_success_ptr &= StaticTestTensorCheck1d<nelems>(tensor_vgpr);
|
||||
|
||||
*casted_success_ptr &= StaticTestTensorCheck1d<nelems>(tensor_sgpr);
|
||||
}
|
||||
|
||||
TEST(TestTensor, ReadWriteGlobalLdsRegistersMemory)
|
||||
{
|
||||
constexpr ck::index_t nelems = 8;
|
||||
std::array<ck::index_t, nelems> host_data;
|
||||
|
||||
DeviceMem data_buf(nelems * sizeof(ck::index_t));
|
||||
data_buf.ToDevice(&host_data[0]);
|
||||
DeviceMem success_buf(sizeof(bool));
|
||||
|
||||
launch_and_time_kernel(StreamConfig{},
|
||||
TestTensorReadWriteDevice,
|
||||
dim3(1),
|
||||
dim3(1),
|
||||
nelems * sizeof(ck::index_t),
|
||||
data_buf.GetDeviceBuffer(),
|
||||
success_buf.GetDeviceBuffer());
|
||||
|
||||
bool success;
|
||||
success_buf.FromDevice(&success);
|
||||
EXPECT_TRUE(success);
|
||||
}
|
||||
|
||||
TEST(TestTensor, Slicing)
|
||||
{
|
||||
constexpr ck::index_t nelems = 8;
|
||||
|
||||
std::array<ck::index_t, nelems> data;
|
||||
const auto shape = ck::make_tuple(ck::make_tuple(2, 2), 2);
|
||||
const auto strides = ck::make_tuple(ck::make_tuple(1, 2), 4);
|
||||
const auto layout = ck::wrapper::make_layout(shape, strides);
|
||||
auto tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(&data[0], layout);
|
||||
InitTensor(tensor);
|
||||
|
||||
auto tensor2x2x2 =
|
||||
tensor(ck::make_tuple(ck::wrapper::slice(2), ck::wrapper::slice(2)), ck::wrapper::slice(2));
|
||||
EXPECT_EQ(ck::wrapper::rank(tensor2x2x2), 2);
|
||||
EXPECT_EQ(ck::wrapper::depth(tensor2x2x2), 2);
|
||||
EXPECT_EQ(ck::wrapper::size(tensor2x2x2), 8);
|
||||
EXPECT_TRUE(TestTensorCheck1d(tensor2x2x2));
|
||||
|
||||
auto tensor2x2 = tensor(ck::make_tuple(1, ck::wrapper::slice(2)), ck::wrapper::slice(2));
|
||||
EXPECT_EQ(ck::wrapper::rank(tensor2x2), 2);
|
||||
EXPECT_EQ(ck::wrapper::depth(tensor2x2), 2);
|
||||
EXPECT_EQ(ck::wrapper::size(tensor2x2), 4);
|
||||
EXPECT_TRUE(TestTensorCheck1d(tensor2x2, layout(ck::make_tuple(ck::make_tuple(1, 0), 0))));
|
||||
|
||||
auto tensor1x1 = tensor(ck::make_tuple(1, ck::wrapper::slice(1, 2)), ck::wrapper::slice(1, 2));
|
||||
EXPECT_EQ(rank(tensor1x1), 2);
|
||||
EXPECT_EQ(depth(tensor1x1), 2);
|
||||
EXPECT_EQ(size(tensor1x1), 1);
|
||||
EXPECT_TRUE(TestTensorCheck1d(tensor1x1, layout(ck::make_tuple(ck::make_tuple(1, 1), 1))));
|
||||
|
||||
auto tensor2 = tensor(ck::make_tuple(1, 1), ck::wrapper::slice(0, 2));
|
||||
EXPECT_EQ(ck::wrapper::rank(tensor2), 1);
|
||||
EXPECT_EQ(ck::wrapper::depth(tensor2), 1);
|
||||
EXPECT_EQ(ck::wrapper::size(tensor2), 2);
|
||||
EXPECT_TRUE(TestTensorCheck1d(tensor2, layout(ck::make_tuple(ck::make_tuple(1, 1), 0))));
|
||||
|
||||
// negative indexing
|
||||
auto tensor1x2 = tensor(ck::make_tuple(1, ck::wrapper::slice(0, -2)), ck::wrapper::slice());
|
||||
EXPECT_EQ(rank(tensor1x2), 2);
|
||||
EXPECT_EQ(depth(tensor1x2), 2);
|
||||
EXPECT_EQ(size(tensor1x2), 2);
|
||||
EXPECT_TRUE(TestTensorCheck1d(tensor1x2, layout(ck::make_tuple(ck::make_tuple(1, 0), 0))));
|
||||
}
|
||||
Reference in New Issue
Block a user