mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Add optimized copy to ck wrapper (#1126)
* Add optimized copy to ck wrapper * Example optimizations * Fixes * Move img2col test to client example * Refactor example * Fix docs * Fixes * Fix * Fixes * Fixes * Fixes * Fixes * Fixes --------- Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -22,14 +22,19 @@ namespace wrapper {
|
||||
// Disable from doxygen docs generation
|
||||
/// @cond
|
||||
// forward declaration
|
||||
template <typename Shape, typename UnnestedDescriptorType>
|
||||
template <typename Shape, typename UnrolledDescriptorType>
|
||||
struct Layout;
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
namespace {
|
||||
// Generate packed (column-major) strides if not passed
|
||||
/**
|
||||
* \brief Generate packed (column-major) strides if not passed
|
||||
*
|
||||
* \param shape Tensor shape.
|
||||
* \return Generated column-major strides.
|
||||
*/
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr static auto
|
||||
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
|
||||
@@ -50,9 +55,16 @@ GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
|
||||
Number<decltype(unrolled_shape)::Size()>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Create naive tensor descriptor from nested shape.
|
||||
*
|
||||
* \param shape Tensor shape.
|
||||
* \param strides Tensor strides.
|
||||
* \return Unrolled descriptor
|
||||
*/
|
||||
template <typename LayoutShape, typename LayoutStrides>
|
||||
__host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shape,
|
||||
const LayoutStrides& strides)
|
||||
__host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& shape,
|
||||
const LayoutStrides& strides)
|
||||
{
|
||||
const auto unrolled_shape = UnrollNestedTuple(shape);
|
||||
if constexpr(is_same_v<LayoutStrides, Tuple<>>)
|
||||
@@ -86,8 +98,8 @@ __host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shap
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
|
||||
{
|
||||
using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Strides{}));
|
||||
return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, strides));
|
||||
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
|
||||
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -100,15 +112,19 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides
|
||||
template <typename Shape>
|
||||
__host__ __device__ constexpr auto make_layout(const Shape& shape)
|
||||
{
|
||||
using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Tuple<>{}));
|
||||
return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, Tuple<>{}));
|
||||
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
|
||||
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
|
||||
}
|
||||
|
||||
// Layout helpers
|
||||
// get
|
||||
// Get dim (could be returned from get with empty Idxs)
|
||||
|
||||
/**
|
||||
* \private
|
||||
* \brief Get dim.
|
||||
*
|
||||
* \param dim Dimension.
|
||||
* \return The same dimension, unchanged.
|
||||
*/
|
||||
template <typename T>
|
||||
__host__ __device__ T constexpr get(const T& dim)
|
||||
@@ -178,7 +194,7 @@ __host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
|
||||
},
|
||||
Number<old_shape_dims>{});
|
||||
|
||||
const auto& flatten_desc = layout.GetUnnestedDescriptor();
|
||||
const auto& flatten_desc = layout.GetUnrolledDescriptor();
|
||||
auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
|
||||
return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
|
||||
}
|
||||
@@ -197,9 +213,12 @@ __host__ __device__ constexpr auto get(const T& elem)
|
||||
}
|
||||
|
||||
// size
|
||||
// Get dim size (could be returned from get function)
|
||||
/**
|
||||
* \private
|
||||
* \brief Get size.
|
||||
*
|
||||
* \param dim Size.
|
||||
* \return The same size, unchanged.
|
||||
*/
|
||||
template <typename T>
|
||||
__host__ __device__ T constexpr size(const T& dim)
|
||||
@@ -214,8 +233,8 @@ __host__ __device__ T constexpr size(const T& dim)
|
||||
* \param layout Layout to get Shape of.
|
||||
* \return Requested length.
|
||||
*/
|
||||
template <index_t idx, typename Shape, typename UnnestedDescriptorType>
|
||||
__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
|
||||
template <index_t idx, typename Shape, typename UnrolledDescriptorType>
|
||||
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
|
||||
{
|
||||
return layout.template GetLength<idx>();
|
||||
}
|
||||
@@ -240,8 +259,8 @@ __host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
|
||||
* \param layout Layout to calculate shape size.
|
||||
* \return Requested size.
|
||||
*/
|
||||
template <typename Shape, typename UnnestedDescriptorType>
|
||||
__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
|
||||
template <typename Shape, typename UnrolledDescriptorType>
|
||||
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
|
||||
{
|
||||
return layout.GetLengths();
|
||||
}
|
||||
@@ -280,9 +299,9 @@ __host__ __device__ constexpr auto size(const T& elem)
|
||||
* \param layout Layout to calculate rank.
|
||||
* \return Requested rank.
|
||||
*/
|
||||
template <typename Shape, typename UnnestedDescriptorType>
|
||||
template <typename Shape, typename UnrolledDescriptorType>
|
||||
__host__ __device__ constexpr auto
|
||||
rank([[maybe_unused]] const Layout<Shape, UnnestedDescriptorType>& layout)
|
||||
rank([[maybe_unused]] const Layout<Shape, UnrolledDescriptorType>& layout)
|
||||
{
|
||||
return Shape::Size();
|
||||
}
|
||||
@@ -302,17 +321,25 @@ __host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& t
|
||||
|
||||
/**
|
||||
* \private
|
||||
* \brief Rank for scalar
|
||||
*
|
||||
* \param dim Dimension scalar.
|
||||
* \return Always 1.
|
||||
*/
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr index_t rank(const Number<IDim>&)
|
||||
__host__ __device__ constexpr index_t rank([[maybe_unused]] const Number<IDim>& dim)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* \private
|
||||
* \brief Rank for scalar
|
||||
*
|
||||
* \param dim Dimension scalar.
|
||||
* \return Always 1.
|
||||
*/
|
||||
__host__ __device__ constexpr index_t rank(const index_t&) { return 1; }
|
||||
__host__ __device__ constexpr index_t rank([[maybe_unused]] const index_t& dim) { return 1; }
|
||||
|
||||
/**
|
||||
* \brief Hierarchical rank.
|
||||
@@ -334,8 +361,8 @@ __host__ __device__ constexpr auto rank(const T& elem)
|
||||
* \param layout Layout to calculate depth.
|
||||
* \return Requested depth.
|
||||
*/
|
||||
template <typename Shape, typename UnnestedDescriptorType>
|
||||
__host__ __device__ constexpr auto depth(const Layout<Shape, UnnestedDescriptorType>& layout)
|
||||
template <typename Shape, typename UnrolledDescriptorType>
|
||||
__host__ __device__ constexpr auto depth(const Layout<Shape, UnrolledDescriptorType>& layout)
|
||||
{
|
||||
const auto& shape = layout.GetShape();
|
||||
return TupleDepth(shape);
|
||||
@@ -355,17 +382,25 @@ __host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
|
||||
|
||||
/**
|
||||
* \private
|
||||
* \brief Depth for scalar
|
||||
*
|
||||
* \param dim Scalar.
|
||||
* \return Always 0.
|
||||
*/
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr index_t depth(const Number<IDim>&)
|
||||
__host__ __device__ constexpr index_t depth([[maybe_unused]] const Number<IDim>& dim)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* \private
|
||||
* \brief Depth for scalar
|
||||
*
|
||||
* \param dim Scalar.
|
||||
* \return Always 0.
|
||||
*/
|
||||
__host__ __device__ constexpr index_t depth(const index_t&) { return 0; }
|
||||
__host__ __device__ constexpr index_t depth([[maybe_unused]] const index_t& dim) { return 0; }
|
||||
|
||||
/**
|
||||
* \brief Hierarchical depth.
|
||||
|
||||
Reference in New Issue
Block a user