mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
rework sequence
This commit is contained in:
@@ -57,7 +57,7 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0
|
||||
template <class Lengths, class Strides>
|
||||
struct ConstantTensorDescriptor
|
||||
{
|
||||
using Type = ConstantTensorDescriptor<Lengths, Strides>;
|
||||
using Type = ConstantTensorDescriptor;
|
||||
static constexpr index_t nDim = Lengths::GetSize();
|
||||
|
||||
__host__ __device__ constexpr ConstantTensorDescriptor()
|
||||
@@ -195,19 +195,14 @@ struct ConstantTensorDescriptor
|
||||
Number<unfold_stride>{} *
|
||||
reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies<index_t>{});
|
||||
|
||||
// left and right lengths
|
||||
constexpr auto lengths_pair = GetLengths().Split(Number<IDim>{});
|
||||
constexpr auto left_lengths = lengths_pair.first;
|
||||
constexpr auto right_lengths = lengths_pair.second.PopFront();
|
||||
|
||||
// left and right strides
|
||||
constexpr auto strides_pair = GetStrides().Split(Number<IDim>{});
|
||||
constexpr auto left_strides = strides_pair.first;
|
||||
constexpr auto right_strides = strides_pair.second.PopFront();
|
||||
// left and right
|
||||
constexpr auto left = make_increasing_sequence(Number<0>{}, Number<IDim>{}, Number<1>{});
|
||||
constexpr auto right = make_increasing_sequence(
|
||||
Number<IDim + 1>{}, Number<GetNumOfDimension()>{}, Number<1>{});
|
||||
|
||||
return make_ConstantTensorDescriptor(
|
||||
left_lengths.Append(fold_lengths).Append(right_lengths),
|
||||
left_strides.Append(fold_strides).Append(right_strides));
|
||||
GetLengths().Extract(left).Append(fold_lengths).Append(GetLengths().Extract(right)),
|
||||
GetStrides().Extract(left).Append(fold_strides).Append(GetStrides().Extract(right)));
|
||||
}
|
||||
|
||||
template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
|
||||
@@ -228,40 +223,28 @@ struct ConstantTensorDescriptor
|
||||
"wrong! dimensions to be unfolded need to be packed");
|
||||
});
|
||||
|
||||
// lengths
|
||||
constexpr auto lens_pair1 = Lengths{}.Split(Number<LastUnfoldDim + 1>{});
|
||||
// left and right
|
||||
constexpr auto left =
|
||||
make_increasing_sequence(Number<0>{}, Number<FirstUnfoldDim>{}, Number<1>{});
|
||||
constexpr auto middle = make_increasing_sequence(
|
||||
Number<FirstUnfoldDim>{}, Number<LastUnfoldDim + 1>{}, Number<1>{});
|
||||
constexpr auto right = make_increasing_sequence(
|
||||
Number<LastUnfoldDim + 1>{}, Number<GetNumOfDimension()>{}, Number<1>{});
|
||||
|
||||
constexpr auto right_lengths = lens_pair1.second;
|
||||
// length and stride
|
||||
constexpr index_t unfold_length = accumulate_on_sequence(
|
||||
GetLengths().Extract(middle), std::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
constexpr auto lens_pair2 = lens_pair1.first.Split(Number<FirstUnfoldDim>{});
|
||||
constexpr index_t unfold_stride = GetStride(Number<LastUnfoldDim>{});
|
||||
|
||||
constexpr auto left_lengths = lens_pair2.first;
|
||||
|
||||
constexpr auto fold_lengths = lens_pair2.second;
|
||||
|
||||
constexpr index_t unfold_length =
|
||||
accumulate_on_sequence(fold_lengths, std::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
constexpr auto new_lengths =
|
||||
left_lengths.PopBack(Number<unfold_length>{}).Append(right_lengths);
|
||||
|
||||
// strides
|
||||
constexpr auto strides_pair1 = Strides{}.Split(Number<LastUnfoldDim + 1>{});
|
||||
|
||||
constexpr auto right_strides = strides_pair1.second;
|
||||
|
||||
constexpr auto strides_pair2 = strides_pair1.first.Split(Number<FirstUnfoldDim>{});
|
||||
|
||||
constexpr auto left_strides = strides_pair2.first;
|
||||
|
||||
constexpr auto fold_strides = strides_pair2.second;
|
||||
|
||||
constexpr index_t unfold_stride = fold_strides.Back();
|
||||
|
||||
constexpr auto new_strides =
|
||||
left_strides.PushBack(Number<unfold_stride>{}).Append(right_strides);
|
||||
|
||||
return make_ConstantTensorDescriptor(new_lengths, new_strides);
|
||||
return make_ConstantTensorDescriptor(GetLengths()
|
||||
.Extract(left)
|
||||
.PushBack(Number<unfold_length>{})
|
||||
.Append(GetLengths().Extract(right)),
|
||||
GetStrides()
|
||||
.Extract(left)
|
||||
.PushBack(Number<unfold_stride>{})
|
||||
.Append(GetStrides().Extract(right)));
|
||||
}
|
||||
|
||||
template <index_t... IRs>
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
#include "constant_integral.hip.hpp"
|
||||
#include "functional.hip.hpp"
|
||||
|
||||
struct EmptySequence;
|
||||
|
||||
template <index_t... Is>
|
||||
struct Sequence
|
||||
{
|
||||
using Type = Sequence<Is...>;
|
||||
using Type = Sequence;
|
||||
|
||||
static constexpr index_t mSize = sizeof...(Is);
|
||||
|
||||
@@ -72,101 +70,62 @@ struct Sequence
|
||||
return Sequence<Is..., Xs...>{};
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto Append(EmptySequence) const;
|
||||
|
||||
template <index_t... Ns>
|
||||
__host__ __device__ constexpr auto Extract(Number<Ns>...) const
|
||||
{
|
||||
return Sequence<Get(Number<Ns>{})...>{};
|
||||
}
|
||||
|
||||
template <index_t N>
|
||||
struct split_impl
|
||||
template <index_t... Ns>
|
||||
__host__ __device__ constexpr auto Extract(Sequence<Ns...>) const
|
||||
{
|
||||
template <class FirstSeq, class SecondSeq>
|
||||
__host__ __device__ constexpr auto operator()(FirstSeq, SecondSeq) const
|
||||
{
|
||||
constexpr index_t new_first = FirstSeq{}.PushBack(Number<SecondSeq{}.Front()>{});
|
||||
constexpr index_t new_second = SecondSeq{}.PopFront();
|
||||
|
||||
static_if<(N > 0)>{}([&](auto fwd) {
|
||||
return split_impl<N - 1>{}(new_first, fwd(new_second));
|
||||
}).else_([&](auto fwd) { return std::make_pair(new_first, fwd(new_second)); });
|
||||
}
|
||||
};
|
||||
|
||||
// split one sequence to two sequnces: [0, I) and [I, mSize)
|
||||
// return type is std::pair
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto Split(Number<I>) const;
|
||||
|
||||
template <index_t I, index_t X>
|
||||
__host__ __device__ constexpr auto Modify(Number<I>, Number<X>) const
|
||||
{
|
||||
constexpr auto first_second = Split(Number<I>{});
|
||||
|
||||
constexpr auto left = first_second.first;
|
||||
constexpr auto right = first_second.second.PopFront();
|
||||
|
||||
return left.PushBack(Number<X>{}).Append(right);
|
||||
return Sequence<Get(Number<Ns>{})...>{};
|
||||
}
|
||||
};
|
||||
|
||||
struct EmptySequence
|
||||
template <class, class>
|
||||
struct sequence_merge;
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
|
||||
{
|
||||
__host__ __device__ static constexpr index_t GetSize() { return 0; }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto PushFront(Number<I>) const
|
||||
{
|
||||
return Sequence<I>{};
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto PushBack(Number<I>) const
|
||||
{
|
||||
return Sequence<I>{};
|
||||
}
|
||||
|
||||
template <class Seq>
|
||||
__host__ __device__ constexpr Seq Append(Seq) const
|
||||
{
|
||||
return Seq{};
|
||||
}
|
||||
using Type = Sequence<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto Sequence<Is...>::Append(EmptySequence) const
|
||||
template <index_t IBegin, index_t NSize, index_t Increment>
|
||||
struct increasing_sequence_gen
|
||||
{
|
||||
return Type{};
|
||||
}
|
||||
static constexpr index_t NSizeLeft = NSize / 2;
|
||||
|
||||
// split one sequence to two sequnces: [0, I) and [I, mSize)
|
||||
// return type is std::pair
|
||||
template <index_t... Is>
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto Sequence<Is...>::Split(Number<I>) const
|
||||
using Type =
|
||||
sequence_merge<typename increasing_sequence_gen<IBegin, NSizeLeft, Increment>::Type,
|
||||
typename increasing_sequence_gen<IBegin + NSizeLeft * Increment,
|
||||
NSize - NSizeLeft,
|
||||
Increment>::Type>;
|
||||
};
|
||||
|
||||
template <index_t IBegin, index_t Increment>
|
||||
struct increasing_sequence_gen<IBegin, 1, Increment>
|
||||
{
|
||||
static_assert(I <= GetSize(), "wrong! split position is too high!");
|
||||
using Type = Sequence<IBegin>;
|
||||
};
|
||||
|
||||
static_if<(I == 0)>{}([&](auto fwd) { return std::make_pair(EmptySequence{}, fwd(Type{})); });
|
||||
template <index_t IBegin, index_t Increment>
|
||||
struct increasing_sequence_gen<IBegin, 0, Increment>
|
||||
{
|
||||
using Type = Sequence<>;
|
||||
};
|
||||
|
||||
static_if<(I == GetSize())>{}(
|
||||
[&](auto fwd) { return std::make_pair(Type{}, fwd(EmptySequence{})); });
|
||||
|
||||
static_if<(I > 0 && I < GetSize())>{}(
|
||||
[&](auto fwd) { return split_impl<I>{}(EmptySequence{}, fwd(Type{})); });
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <index_t IBegin, index_t IEnd, index_t Increment>
|
||||
__host__ __device__ auto make_increasing_sequence(Number<IBegin>, Number<IEnd>, Number<Increment>)
|
||||
__host__ __device__ constexpr auto
|
||||
make_increasing_sequence(Number<IBegin>, Number<IEnd>, Number<Increment>)
|
||||
{
|
||||
static_assert(IBegin < IEnd, (IEnd - IBegin) % Increment == 0, "wrong!");
|
||||
static_assert(IBegin <= IEnd && Increment > 0, "wrong!");
|
||||
|
||||
// not implemented
|
||||
constexpr index_t NSize = (IEnd - IBegin) / Increment;
|
||||
|
||||
return increasing_sequence_gen<IBegin, NSize, Increment>{};
|
||||
}
|
||||
#endif
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
|
||||
@@ -222,7 +181,7 @@ __host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
|
||||
{
|
||||
constexpr auto seq_x = Sequence<Xs...>{};
|
||||
|
||||
#if 0
|
||||
#if 0 // doesn't compile
|
||||
static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
|
||||
constexpr auto I = decltype(Iter){};
|
||||
static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow");
|
||||
|
||||
Reference in New Issue
Block a user