mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-26 16:04:58 +00:00
refactoring
This commit is contained in:
@@ -29,12 +29,9 @@ struct Sequence
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr index_t operator[](Number<I>) const
|
||||
__host__ __device__ constexpr auto operator[](Number<I>) const
|
||||
{
|
||||
static_assert(I < mSize, "wrong! I too large");
|
||||
|
||||
const index_t mData[mSize + 1] = {Is..., 0};
|
||||
return mData[I];
|
||||
return Number<Get(Number<I>{})>{};
|
||||
}
|
||||
|
||||
// make sure I is constexpr
|
||||
@@ -69,24 +66,30 @@ struct Sequence
|
||||
return mData[mSize - 1];
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static constexpr auto PushFront(Number<I>)
|
||||
{
|
||||
return Sequence<I, Is...>{};
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static constexpr auto PushBack(Number<I>)
|
||||
{
|
||||
return Sequence<Is..., I>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto PopFront();
|
||||
|
||||
__host__ __device__ static constexpr auto PopBack();
|
||||
|
||||
template <index_t... Xs>
|
||||
__host__ __device__ static constexpr auto Append(Sequence<Xs...>)
|
||||
__host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<Xs..., Is...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs>
|
||||
__host__ __device__ static constexpr auto PushFront(Number<Xs>...)
|
||||
{
|
||||
return Sequence<Xs..., Is...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs>
|
||||
__host__ __device__ static constexpr auto PushBack(Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<Is..., Xs...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs>
|
||||
__host__ __device__ static constexpr auto PushBack(Number<Xs>...)
|
||||
{
|
||||
return Sequence<Is..., Xs...>{};
|
||||
}
|
||||
@@ -105,6 +108,12 @@ struct Sequence
|
||||
|
||||
template <index_t I, index_t X>
|
||||
__host__ __device__ static constexpr auto Modify(Number<I>, Number<X>);
|
||||
|
||||
template <class F>
|
||||
__host__ __device__ static constexpr auto Transform(F f)
|
||||
{
|
||||
return Sequence<f(Is)...>{};
|
||||
}
|
||||
};
|
||||
|
||||
// merge sequence
|
||||
@@ -114,7 +123,7 @@ struct sequence_merge;
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
|
||||
{
|
||||
using SeqType = Sequence<Xs..., Ys...>;
|
||||
using type = Sequence<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
// arithmetic sequence
|
||||
@@ -123,40 +132,29 @@ struct arithmetic_sequence_gen_impl
|
||||
{
|
||||
static constexpr index_t NSizeLeft = NSize / 2;
|
||||
|
||||
using SeqType = typename sequence_merge<
|
||||
typename arithmetic_sequence_gen_impl<IBegin, NSizeLeft, Increment>::SeqType,
|
||||
using type = typename sequence_merge<
|
||||
typename arithmetic_sequence_gen_impl<IBegin, NSizeLeft, Increment>::type,
|
||||
typename arithmetic_sequence_gen_impl<IBegin + NSizeLeft * Increment,
|
||||
NSize - NSizeLeft,
|
||||
Increment>::SeqType>::SeqType;
|
||||
Increment>::type>::type;
|
||||
};
|
||||
|
||||
template <index_t IBegin, index_t Increment>
|
||||
struct arithmetic_sequence_gen_impl<IBegin, 1, Increment>
|
||||
{
|
||||
using SeqType = Sequence<IBegin>;
|
||||
using type = Sequence<IBegin>;
|
||||
};
|
||||
|
||||
template <index_t IBegin, index_t Increment>
|
||||
struct arithmetic_sequence_gen_impl<IBegin, 0, Increment>
|
||||
{
|
||||
using SeqType = Sequence<>;
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
template <index_t IBegin, index_t IEnd, index_t Increment>
|
||||
struct arithmetic_sequence_gen
|
||||
{
|
||||
using SeqType =
|
||||
typename arithmetic_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType;
|
||||
};
|
||||
|
||||
// transform sequence
|
||||
template <class, class>
|
||||
struct sequence_transform;
|
||||
|
||||
template <class F, index_t... Is>
|
||||
struct sequence_transform<F, Sequence<Is...>>
|
||||
{
|
||||
using SeqType = Sequence<F{}(Is)...>;
|
||||
using type = typename arithmetic_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::type;
|
||||
};
|
||||
|
||||
// uniform sequence
|
||||
@@ -168,9 +166,8 @@ struct uniform_sequence_gen
|
||||
__host__ __device__ constexpr index_t operator()(index_t) const { return I; }
|
||||
};
|
||||
|
||||
using SeqType = typename sequence_transform<
|
||||
return_constant,
|
||||
typename arithmetic_sequence_gen<0, NSize, 1>::SeqType>::SeqType;
|
||||
using type = decltype(
|
||||
typename arithmetic_sequence_gen<0, NSize, 1>::type{}.Transform(return_constant{}));
|
||||
};
|
||||
|
||||
// reverse inclusive scan (with init) sequence
|
||||
@@ -180,34 +177,23 @@ struct sequence_reverse_inclusive_scan;
|
||||
template <index_t I, index_t... Is, class Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
|
||||
{
|
||||
using old_scan =
|
||||
typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::SeqType;
|
||||
using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
|
||||
|
||||
static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
|
||||
|
||||
using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
|
||||
using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
|
||||
};
|
||||
|
||||
template <index_t I, class Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
|
||||
{
|
||||
using SeqType = Sequence<Reduce{}(I, Init)>;
|
||||
using type = Sequence<Reduce{}(I, Init)>;
|
||||
};
|
||||
|
||||
template <class Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
|
||||
{
|
||||
using SeqType = Sequence<>;
|
||||
};
|
||||
|
||||
// extract sequence
|
||||
template <class, class>
|
||||
struct sequence_extract;
|
||||
|
||||
template <class Seq, index_t... Is>
|
||||
struct sequence_extract<Seq, Sequence<Is...>>
|
||||
{
|
||||
using SeqType = Sequence<Seq{}.Get(Number<Is>{})...>;
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
// split sequence
|
||||
@@ -216,11 +202,11 @@ struct sequence_split
|
||||
{
|
||||
static constexpr index_t NSize = Seq{}.GetSize();
|
||||
|
||||
using range0 = typename arithmetic_sequence_gen<0, I, 1>::SeqType;
|
||||
using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::SeqType;
|
||||
using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
|
||||
using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
|
||||
|
||||
using SeqType0 = typename sequence_extract<Seq, range0>::SeqType;
|
||||
using SeqType1 = typename sequence_extract<Seq, range1>::SeqType;
|
||||
using SeqType0 = decltype(Seq::Extract(range0{}));
|
||||
using SeqType1 = decltype(Seq::Extract(range1{}));
|
||||
};
|
||||
|
||||
// reverse sequence
|
||||
@@ -230,31 +216,31 @@ struct sequence_reverse
|
||||
static constexpr index_t NSize = Seq{}.GetSize();
|
||||
|
||||
using seq_split = sequence_split<Seq, NSize / 2>;
|
||||
using SeqType = typename sequence_merge<
|
||||
typename sequence_reverse<typename seq_split::SeqType1>::SeqType,
|
||||
typename sequence_reverse<typename seq_split::SeqType0>::SeqType>::SeqType;
|
||||
using type = typename sequence_merge<
|
||||
typename sequence_reverse<typename seq_split::SeqType1>::type,
|
||||
typename sequence_reverse<typename seq_split::SeqType0>::type>::type;
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
struct sequence_reverse<Sequence<I>>
|
||||
{
|
||||
using SeqType = Sequence<I>;
|
||||
using type = Sequence<I>;
|
||||
};
|
||||
|
||||
template <index_t I0, index_t I1>
|
||||
struct sequence_reverse<Sequence<I0, I1>>
|
||||
{
|
||||
using SeqType = Sequence<I1, I0>;
|
||||
using type = Sequence<I1, I0>;
|
||||
};
|
||||
|
||||
template <class Seq>
|
||||
struct is_valid_sequence_map
|
||||
{
|
||||
static constexpr bool value = true;
|
||||
static constexpr integral_constant<bool, true> value = integral_constant<bool, true>{};
|
||||
|
||||
// TODO: add proper check for is_valid, something like:
|
||||
// static constexpr bool value =
|
||||
// is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::SeqType,
|
||||
// is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::type,
|
||||
// typename sequence_sort<Seq>::SortedSeqType>{};
|
||||
};
|
||||
|
||||
@@ -401,7 +387,7 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
|
||||
template <class Seq, class Reduce, index_t Init>
|
||||
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
|
||||
{
|
||||
return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::SeqType{};
|
||||
return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
|
||||
}
|
||||
|
||||
template <class Seq, class Reduce, index_t Init>
|
||||
@@ -425,7 +411,7 @@ __host__ __device__ constexpr auto Sequence<Is...>::PopBack()
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto Sequence<Is...>::Reverse()
|
||||
{
|
||||
return typename sequence_reverse<Sequence<Is...>>::SeqType{};
|
||||
return typename sequence_reverse<Sequence<Is...>>::type{};
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
@@ -438,7 +424,7 @@ __host__ __device__ constexpr auto Sequence<Is...>::Modify(Number<I>, Number<X>)
|
||||
constexpr auto seq_left = typename seq_split::SeqType0{};
|
||||
constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();
|
||||
|
||||
return seq_left.PushBack(Number<X>{}).Append(seq_right);
|
||||
return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
|
||||
}
|
||||
|
||||
template <index_t... Xs>
|
||||
|
||||
Reference in New Issue
Block a user