Add tensor partition and generic copy for ck wrapper (#1108)

* Add tensor partition and generic copy for ck wrapper

* Update changelog

* Stylistic fixes

* Change shape/strides logic to descriptor transforms

* Fixes

* Fix client example

* Fix comments
This commit is contained in:
Bartłomiej Kocot
2024-01-03 01:10:57 +01:00
committed by GitHub
parent b268f273de
commit 4234b3a691
14 changed files with 940 additions and 306 deletions

View File

@@ -27,12 +27,12 @@ using MemoryTypeEnum = AddressSpaceEnum;
// Disable from doxygen docs generation
/// @cond
// forward declarations
template <typename Shape, typename Strides>
template <typename Shape, typename UnnestedDescriptorType>
struct Layout;
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
typename UnnestedDescriptorType,
index_t NumVectors, // params for Register memory
index_t ScalarPerVector // param for Register memory
>
@@ -98,11 +98,19 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
* \param layout Tensor layout.
* \return Constructed tensor.
*/
template <MemoryTypeEnum MemoryType, typename ElementType, typename Shape, typename Strides>
constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& layout)
template <MemoryTypeEnum MemoryType,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType>
constexpr auto make_tensor(ElementType* pointer,
const Layout<Shape, UnnestedDescriptorType>& layout)
{
// Wraps caller-provided memory in a Tensor: `pointer` and `layout` are passed
// unchanged to the Tensor constructor, and both register-only parameters
// (NumVectors, ScalarPerVector) are fixed to 0.
// NOTE(review): this listing shows the pre-change signature/return statement
// (fourth template argument `Strides`) immediately followed by the post-change
// ones (fourth template argument `UnnestedDescriptorType`).
return Tensor<MemoryType, ElementType, Shape, Strides, 0 /*NumVectors*/, 0 /*ScalarPerVector*/>(
pointer, layout);
return Tensor<MemoryType,
ElementType,
Shape,
UnnestedDescriptorType,
0 /*NumVectors*/,
0 /*ScalarPerVector*/>(pointer, layout);
}
/**
@@ -112,19 +120,21 @@ constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& l
* \tparam NumVectors Number of vectors.
* \tparam ScalarPerVector Scalars per vector.
* \tparam ElementType Memory data type.
* \param layout Tensor layout.
* \return Constructed tensor.
*/
template <MemoryTypeEnum MemoryType,
index_t NumVectors,
index_t ScalarPerVector,
typename ElementType,
typename Shape,
typename Strides>
constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
typename ElementType>
constexpr auto make_register_tensor()
{
// NOTE(review): pre-change and post-change lines appear back to back here.
// Pre-change: the caller supplied the layout, and a nested Shape was rejected
// at compile time by the static_assert below.
static_assert(!IsNestedTuple(Shape{}), "Register tensor with nested layout is not supported");
return Tensor<MemoryType, ElementType, Shape, Strides, NumVectors, ScalarPerVector>(layout);
// Post-change: no arguments are taken. The layout is built locally from the
// one-element tuples (Number<NumVectors>) and (Number<1>) -- presumably shape
// and stride respectively; confirm against make_layout.
const auto layout = make_layout(make_tuple(Number<NumVectors>{}), make_tuple(Number<1>{}));
// The descriptor template argument is the type returned by
// layout.GetUnnestedDescriptor() with its reference and const stripped.
return Tensor<MemoryType,
ElementType,
Tuple<Number<NumVectors>>,
std::remove_const_t<remove_reference_t<decltype(layout.GetUnnestedDescriptor())>>,
NumVectors,
ScalarPerVector>(layout);
}
/**
@@ -136,12 +146,15 @@ constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
typename UnnestedDescriptorType,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
layout(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
__host__ __device__ constexpr const auto& layout(const Tensor<BufferAddressSpace,
ElementType,
Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
{
// Free-function accessor: hands back whatever tensor.GetLayout() yields, bound
// to a const reference.
// NOTE(review): assumes GetLayout() itself returns a reference to a member;
// if it returned by value this would dangle -- TODO confirm.
return tensor.GetLayout();
}
@@ -157,12 +170,15 @@ template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
typename UnnestedDescriptorType,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr index_t
size(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
__host__ __device__ constexpr auto size(const Tensor<BufferAddressSpace,
ElementType,
Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
{
// Tensor overload of size: forwards the Idxs... pack to the size<Idxs...>
// overload that takes the tensor's layout.
// The pre-change declaration returned index_t; the post-change one returns
// `auto`, i.e. whatever type the layout overload produces.
return size<Idxs...>(tensor.GetLayout());
}
@@ -178,12 +194,15 @@ template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
typename UnnestedDescriptorType,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr index_t
rank(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
__host__ __device__ constexpr auto rank(const Tensor<BufferAddressSpace,
ElementType,
Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
{
// Tensor overload of rank: forwards the Idxs... pack to the rank<Idxs...>
// overload that takes the tensor's layout.
// The pre-change declaration returned index_t; the post-change one returns
// `auto`, i.e. whatever type the layout overload produces.
return rank<Idxs...>(tensor.GetLayout());
}
@@ -199,35 +218,19 @@ template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
typename UnnestedDescriptorType,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr index_t
depth(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
__host__ __device__ constexpr auto depth(const Tensor<BufferAddressSpace,
ElementType,
Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
{
// Tensor overload of depth: forwards the Idxs... pack to the depth<Idxs...>
// overload that takes the tensor's layout.
// The pre-change declaration returned index_t; the post-change one returns
// `auto`, i.e. whatever type the layout overload produces.
return depth<Idxs...>(tensor.GetLayout());
}
/**
* \brief Get Tensor strides.
*
* \param tensor Tensor to get strides from.
* \return Requested strides.
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
{
// Delegates to the stride() overload that takes the tensor's layout.
// NOTE(review): only the Strides-based (pre-change) form of this function is
// visible in this listing; no UnnestedDescriptorType variant appears.
return stride(tensor.GetLayout());
}
/**
* \brief Get Tensor shape.
*
@@ -237,12 +240,15 @@ stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors,
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename Strides,
typename UnnestedDescriptorType,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr const auto&
shape(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
tensor)
__host__ __device__ constexpr const auto& shape(const Tensor<BufferAddressSpace,
ElementType,
Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
{
// Tensor overload of shape: delegates to the shape() overload that takes the
// tensor's layout and returns its result bound to a const reference.
// NOTE(review): assumes the layout overload returns a reference that outlives
// this call -- TODO confirm.
return shape(tensor.GetLayout());
}