mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
adding dimension transformation
This commit is contained in:
28
composable_kernel/include/tensor_description/dimension.hpp
Normal file
28
composable_kernel/include/tensor_description/dimension.hpp
Normal file
@@ -0,0 +1,28 @@
|
||||
#ifndef CK_DIMENSION_HPP
|
||||
#define CK_DIMENSION_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t Length>
|
||||
struct Dimension
|
||||
{
|
||||
__host__ __device__ static constexpr auto GetLength() { return Number<Length>{}; }
|
||||
};
|
||||
|
||||
template <index_t Length, index_t Stride>
|
||||
struct NativeDimension : Dimension<Length>
|
||||
{
|
||||
__host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffset(index_t id) { return id * Stride; }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffsetDiff(index_t id_diff)
|
||||
{
|
||||
return id_diff * Stride;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,217 @@
|
||||
#ifndef CK_DIMENSION_TRANSFORM_HPP
|
||||
#define CK_DIMENSION_TRANSFORM_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t N>
|
||||
using MultiIndex = Array<index_t, N>;
|
||||
|
||||
// LowLengths: Sequence<...>
|
||||
template <class LowLengths>
|
||||
struct PassThrough
|
||||
{
|
||||
static constexpr index_t nDim = LowLengths::GetSize();
|
||||
|
||||
using LowerId = MultiIndex<nDim>;
|
||||
using UpperId = LowerId;
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDim>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperNumOfDimension()
|
||||
{
|
||||
return GetLowerNumOfDimension();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return GetLowerLengths(); }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up) { return id_up; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
|
||||
{
|
||||
return id_up_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
};
|
||||
|
||||
// LowLengths: Sequence<...>
|
||||
template <class LowLengths, class LeftPads, class RightPads>
|
||||
struct Pad
|
||||
{
|
||||
static constexpr index_t nDim = LowLengths::GetSize();
|
||||
|
||||
using LowerId = MultiIndex<nDim>;
|
||||
using UpperId = LowerId;
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDim>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperNumOfDimension()
|
||||
{
|
||||
return GetLowerNumOfDimension();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths()
|
||||
{
|
||||
return GetLowerLengths() + LeftPads + RightPads;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up) { return id_up - LeftPads; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
|
||||
{
|
||||
return id_up_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
};
|
||||
|
||||
// LowLengths: Sequence<...>
|
||||
template <class LowLengths>
|
||||
struct Merge
|
||||
{
|
||||
static constexpr index_t nDimLow = LowLengths::GetSize();
|
||||
static constexpr index_t nDimUp = 1;
|
||||
|
||||
using LowerId = MultiIndex<nDimLow>;
|
||||
using UpperId = MultiIndex<nDimUp>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperNumOfDimension(){return Number<nDimUp>{}};
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDimLow>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths()
|
||||
{
|
||||
return Sequence<accumulate_on_sequence(
|
||||
GetLowerLengths(), math::multiplies<index_t>{}, Number<1>{})>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
|
||||
{
|
||||
LowerId id_low;
|
||||
|
||||
// not implemeneted
|
||||
|
||||
return id_low;
|
||||
}
|
||||
|
||||
// id_low_diff depends on id_low_old, so id_low need to be up-to-date
|
||||
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff, LowerId id_low_old)
|
||||
{
|
||||
LowerId id_low_diff;
|
||||
|
||||
// not implemeneted
|
||||
|
||||
return id_low_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
|
||||
};
|
||||
|
||||
// UpLengths: Sequence<...>
|
||||
template <index_t LowLength, class UpLengths>
|
||||
struct Unmerge
|
||||
{
|
||||
static constexpr index_t nDimLow = 1;
|
||||
static constexpr index_t nDimUp = UpLengths::GetSize();
|
||||
|
||||
__host__ __device__ constexpr Unmerge()
|
||||
{
|
||||
static_assert(LowLength == accumulate_on_sequence(
|
||||
UpLengths{}, math::multiplies<index_t>{}, Number<1>{}),
|
||||
"wrong! UpLengths need to be ");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperNumOfDimension(){return Number<nDimUp>{}};
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDimLow>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
|
||||
{
|
||||
constexpr auto scans = typename sequence_reverse_inclusive_scan<UpLengths,
|
||||
math::multiplies<index_t>,
|
||||
1>::type{};
|
||||
|
||||
LowerId id_low{0};
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto idim) { id_low[0] += id_up[idim] * scans[idim]; });
|
||||
|
||||
return id_low;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
|
||||
{
|
||||
return GetLowerId(id_up_diff);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
};
|
||||
|
||||
// UpLengths: Sequence<...>
|
||||
// Coefficients: Sequence<...>
|
||||
// id_low = coefficients[0, ...nDimUp-1] * id_up[0, ...nDimUp-1] + coefficients[nDimUp]
|
||||
template <index_t LowLength, class UpLengths, class Coefficients>
|
||||
struct Embed
|
||||
{
|
||||
static constexpr index_t nDimLow = 1;
|
||||
static constexpr index_t nDimUp = UpLengths::GetSize();
|
||||
|
||||
static constexpr auto mCoefficients = Coefficients{};
|
||||
|
||||
__host__ __device__ constexpr Embed()
|
||||
{
|
||||
static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
|
||||
"wrong! # of dimensions not consistent");
|
||||
|
||||
constexpr index_t low_id_max =
|
||||
Coefficents.Back() + accumulate_on_sequence(UpLengths{} * Coefficients::PopBack(),
|
||||
math::plus<index_t>{},
|
||||
Number<0>{});
|
||||
|
||||
static_assert(low_id_max < LowLength, "wrong! lower-id will go out of range");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperNumOfDimension(){return Number<nDimUp>{}};
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDimLow>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
|
||||
{
|
||||
LowerId id_low{mCoefficients[nDimUp]};
|
||||
|
||||
static_for<0, nDimUp, 1>{}(
|
||||
[&](auto idim) { id_low[0] += id_up[idim] * mCoefficients[idim]; });
|
||||
|
||||
return id_low;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
|
||||
{
|
||||
LowerId id_low_diff{0};
|
||||
|
||||
static_for<0, nDimUp, 1>{}(
|
||||
[&](auto idim) { id_low_diff[0] += id_up_diff[idim] * mCoefficients[idim]; });
|
||||
|
||||
return id_low_diff;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,104 @@
|
||||
#ifndef CK_TENSOR_DESCRIPTOR_HPP
|
||||
#define CK_TENSOR_DESCRIPTOR_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dimension.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class Lengths, class Strides>
|
||||
struct NativeTensorDescriptor
|
||||
{
|
||||
using type = NativeTensorDescriptor;
|
||||
static constexpr index_t nDim = Lengths::GetSize();
|
||||
|
||||
using Id = MultiIndex<nDim>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetStrides() { return Strides{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLength(index_t IDim) { return Lengths{}[IDim]; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetStride(index_t IDim) { return Strides{}[IDim]; }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffset(Id id)
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
};
|
||||
|
||||
// LowerTensorDescriptor
|
||||
// Transforms: std::tuple<DimensionTransforms...>
|
||||
// LowerIds: std::tuple<Sequence<...>>
|
||||
// UpperIds: std::tuple<Sequence<...>>
|
||||
template <class LowTensorDescriptor,
|
||||
class Transforms,
|
||||
class LowDimensionMasks,
|
||||
class UpDimensionMasks>
|
||||
struct TransformedTensorDescriptor
|
||||
{
|
||||
using type = TransformedTensorDescriptor;
|
||||
static constexpr index_t nDimUp = xxxx;
|
||||
static constexpr index_t nDimLow = xxx;
|
||||
|
||||
static constexpr index_t nTransform = Transforms::GetSize();
|
||||
|
||||
using UpperId = MultiIndex<nDimUp>;
|
||||
using LowerId = MultiIndex<nDimLow>;
|
||||
|
||||
__host__ __device__ static constexpr TransformedTensorDescriptor()
|
||||
{
|
||||
static_assert(nTransform == Transforms::GetSize() &&
|
||||
nTransform == LowDimensionMasks::GetSize() &&
|
||||
nTransform == UpDimensionMasks::GetSize(),
|
||||
"wrong! # of transformations not the same");
|
||||
|
||||
// TODO: sanity check: LowDimensionMasks should include all low-dimensions,
|
||||
// UpDimensionMasks should include all up-dimensions
|
||||
|
||||
// TODO: sanity check: while a up-dimension could be associated with multille
|
||||
// transformation,
|
||||
// a low-dimension should be associated with only one transformation
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetNumOfDimension()
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths()
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerTensorDescriptor()
|
||||
{
|
||||
return LowTensorDescriptor{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetLowerId(UpperId id_up)
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetOffset(UpperId id_up)
|
||||
{
|
||||
return GetLowerTensorDescriptor().GetOffset(GetLowerId(id_up));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto AreUpperId2OffsetLinear();
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetIndependentDimensionGroups()
|
||||
{
|
||||
// not implemented
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "config.hpp"
|
||||
#include "utility.hpp"
|
||||
#include "integral_constant.hpp"
|
||||
#include "tuple.hpp"
|
||||
#include "math.hpp"
|
||||
#include "vector_type.hpp"
|
||||
#include "Sequence.hpp"
|
||||
|
||||
68
composable_kernel/include/utility/tuple.hpp
Normal file
68
composable_kernel/include/utility/tuple.hpp
Normal file
@@ -0,0 +1,68 @@
|
||||
#ifndef CK_TUPLE_HPP
|
||||
#define CK_TUPLE_HPP
|
||||
|
||||
#include <tuple>
#include <type_traits>
#include <utility>

#include "integral_constant.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class... Ts>
|
||||
struct tuple : public std::tuple<Ts...>
|
||||
{
|
||||
using type = tuple;
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSize() { return std::tuple_size(tuple{}); }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto Get(Number<I>) const
|
||||
{
|
||||
return std::get<I>(*this);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto operator[](Number<I>) const
|
||||
{
|
||||
return Get(Number<I>{}) :
|
||||
}
|
||||
};
|
||||
|
||||
// merge tuple
|
||||
template <class... Tuples>
|
||||
__host__ __device__ constexpr auto merge_tuple(Tuples&&... xs)
|
||||
{
|
||||
return std::tuple_cat(xs...);
|
||||
};
|
||||
|
||||
// generate tuple
|
||||
template <index_t IBegin, index_t NRemain, class F>
|
||||
struct tuple_gen_impl
|
||||
{
|
||||
static constexpr index_t NRemainLeft = NRemain / 2;
|
||||
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
|
||||
static constexpr index_t IMiddle = IBegin + NRemainLeft;
|
||||
|
||||
using type =
|
||||
typename tuple_merge<typename tuple_gen_impl<IBegin, NRemainLeft, F>::type,
|
||||
typename tuple_gen_impl<IMiddle, NRemainRight, F>::type>::type;
|
||||
};
|
||||
|
||||
template <index_t I, class F>
|
||||
struct tuple_gen_impl<I, 1, F>
|
||||
{
|
||||
static constexpr auto x = F{}(Number<I>{});
|
||||
using type = tuple<Is>;
|
||||
};
|
||||
|
||||
template <index_t I, class F>
|
||||
struct sequence_gen_impl<I, 0, F>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
template <index_t NSize, class F>
|
||||
struct sequence_gen
|
||||
{
|
||||
using type = typename sequence_gen_impl<0, NSize, F>::type;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
Reference in New Issue
Block a user