// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include #endif #include "ck/utility/integral_constant.hpp" #include "ck/utility/type.hpp" #include "ck/utility/functional.hpp" #include "ck/utility/math.hpp" namespace ck { template struct static_for; template struct Sequence; template struct sequence_split; template struct sequence_reverse; template struct sequence_map_inverse; template struct is_valid_sequence_map; template __host__ __device__ constexpr auto sequence_pop_front(Sequence); template __host__ __device__ constexpr auto sequence_pop_back(Seq); template struct Sequence { using Type = Sequence; using data_type = index_t; static constexpr index_t mSize = sizeof...(Is); __host__ __device__ static constexpr auto Size() { return Number{}; } __host__ __device__ static constexpr auto GetSize() { return Size(); } __host__ __device__ static constexpr index_t At(index_t I) { // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 const index_t mData[mSize + 1] = {Is..., 0}; return mData[I]; } template __host__ __device__ static constexpr auto At(Number) { static_assert(I < mSize, "wrong! I too large"); return Number{}; } template __host__ __device__ static constexpr auto Get(Number) { return At(Number{}); } template __host__ __device__ constexpr auto operator[](I i) const { return At(i); } template __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence /*new2old*/) { static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! reorder map should have the same size as Sequence to be rerodered"); static_assert(is_valid_sequence_map>::value, "wrong! invalid reorder map"); return Sequence{})...>{}; } // MapOld2New is Sequence<...> template __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New) { static_assert(MapOld2New::Size() == Size(), "wrong! reorder map should have the same size as Sequence to be rerodered"); static_assert(is_valid_sequence_map::value, "wrong! invalid reorder map"); return ReorderGivenNew2Old(typename sequence_map_inverse::type{}); } __host__ __device__ static constexpr auto Reverse() { return typename sequence_reverse::type{}; } __host__ __device__ static constexpr auto Front() { static_assert(mSize > 0, "wrong!"); return At(Number<0>{}); } __host__ __device__ static constexpr auto Back() { static_assert(mSize > 0, "wrong!"); return At(Number{}); } __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); } __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); } template __host__ __device__ static constexpr auto PushFront(Sequence) { return Sequence{}; } template __host__ __device__ static constexpr auto PushFront(Number...) { return Sequence{}; } template __host__ __device__ static constexpr auto PushBack(Sequence) { return Sequence{}; } template __host__ __device__ static constexpr auto PushBack(Number...) { return Sequence{}; } template __host__ __device__ static constexpr auto Extract(Number...) { return Sequence{})...>{}; } template __host__ __device__ static constexpr auto Extract(Sequence) { return Sequence{})...>{}; } template __host__ __device__ static constexpr auto Modify(Number, Number) { static_assert(I < Size(), "wrong!"); using seq_split = sequence_split; constexpr auto seq_left = typename seq_split::left_type{}; constexpr auto seq_right = typename seq_split::right_type{}.PopFront(); return seq_left.PushBack(Number{}).PushBack(seq_right); } template __host__ __device__ static constexpr auto Transform(F f) { return Sequence{}; } __host__ __device__ static void Print() { printf("{"); printf("size %d, ", index_t{Size()}); static_for<0, Size(), 1>{}([&](auto i) { printf("%d ", At(i).value); }); printf("}"); } }; namespace impl { template struct __integer_sequence; template struct __integer_sequence { using seq_type = Sequence; }; } // namespace impl template using make_index_sequence = typename __make_integer_seq::seq_type; // merge sequence template struct sequence_merge { using type = typename sequence_merge::type>::type; }; template struct sequence_merge, Sequence> { using type = Sequence; }; template struct sequence_merge { using type = Seq; }; // generate sequence template struct sequence_gen { template struct sequence_gen_impl { static constexpr index_t NRemainLeft = NRemain / 2; static constexpr index_t NRemainRight = NRemain - NRemainLeft; static constexpr index_t IMiddle = IBegin + NRemainLeft; using type = typename sequence_merge< typename sequence_gen_impl::type, typename sequence_gen_impl::type>::type; }; template struct sequence_gen_impl { static constexpr index_t Is = G{}(Number{}); using type = Sequence; }; template struct sequence_gen_impl { using type = Sequence<>; }; using type = typename sequence_gen_impl<0, NSize, F>::type; }; // arithmetic sequence template struct arithmetic_sequence_gen { struct F { __host__ __device__ constexpr index_t operator()(index_t i) const { return i * Increment + IBegin; } }; using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; using type1 = Sequence<>; static constexpr bool kHasContent = (Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd); using type = typename conditional::type; }; template struct arithmetic_sequence_gen<0, IEnd, 1> { template struct WrapSequence { using type = Sequence; }; // https://reviews.llvm.org/D13786 using type = typename __make_integer_seq::type; }; // uniform sequence template struct uniform_sequence_gen { struct F { __host__ __device__ constexpr index_t operator()(index_t) const { return I; } }; using type = typename sequence_gen::type; }; // reverse inclusive scan (with init) sequence template struct sequence_reverse_inclusive_scan; template struct sequence_reverse_inclusive_scan, Reduce, Init> { using old_scan = typename sequence_reverse_inclusive_scan, Reduce, Init>::type; static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front()); using type = typename sequence_merge, old_scan>::type; }; template struct sequence_reverse_inclusive_scan, Reduce, Init> { using type = Sequence; }; template struct sequence_reverse_inclusive_scan, Reduce, Init> { using type = Sequence<>; }; // split sequence template struct sequence_split { static constexpr index_t NSize = Seq{}.Size(); using range0 = typename arithmetic_sequence_gen<0, I, 1>::type; using range1 = typename arithmetic_sequence_gen::type; using left_type = decltype(Seq::Extract(range0{})); using right_type = decltype(Seq::Extract(range1{})); }; // reverse sequence template struct sequence_reverse { static constexpr index_t NSize = Seq{}.Size(); using seq_split = sequence_split; using type = typename sequence_merge< typename sequence_reverse::type, typename sequence_reverse::type>::type; }; template struct sequence_reverse> { using type = Sequence; }; template struct sequence_reverse> { using type = Sequence; }; #if 1 template struct sequence_reduce { using type = typename sequence_reduce::type>::type; }; template struct sequence_reduce, Sequence> { using type = Sequence; }; template struct sequence_reduce { using type = Seq; }; #endif // Implement sequence_sort and sequence_unique_sort using constexpr functions (C++17) namespace sort_impl { // Temporary arrays to hold values during operations with capacity N and mutable size. template struct IndexedValueArray { index_t values[N > 0 ? N : 1]; index_t ids[N > 0 ? N : 1]; index_t size = 0; }; template constexpr auto make_indexed_value_array(Sequence) { constexpr index_t N = sizeof...(Is); IndexedValueArray result = {{Is...}, {}, N}; for(index_t i = 0; i < N; ++i) { result.ids[i] = i; } return result; } enum class SortField { Values, Ids }; // Perform an insertion sort on an IndexedValueArray. template constexpr auto insertion_sort(IndexedValueArray arr, Compare comp) { for(index_t i = 1; i < arr.size; ++i) { index_t key_val = arr.values[i]; index_t key_id = arr.ids[i]; index_t j = i - 1; while(j >= 0 && comp(key_val, arr.values[j])) { arr.values[j + 1] = arr.values[j]; arr.ids[j + 1] = arr.ids[j]; --j; } arr.values[j + 1] = key_val; arr.ids[j + 1] = key_id; } return arr; } // Remove duplicates from a sorted IndexedValueArray. template constexpr auto unique(const IndexedValueArray& sorted, Equal eq) { IndexedValueArray result{}; if constexpr(N == 0) { return result; } result.size = 1; result.values[0] = sorted.values[0]; result.ids[0] = sorted.ids[0]; for(index_t i = 1; i < sorted.size; ++i) { if(!eq(sorted.values[i], sorted.values[i - 1])) { result.values[result.size] = sorted.values[i]; result.ids[result.size] = sorted.ids[i]; ++result.size; } } return result; } // Compute sorted (and optionally unique) IndexedValueArray from input Sequence. template constexpr auto compute_sorted(Sequence seq, Compare comp, Equal eq) { auto sorted = insertion_sort(make_indexed_value_array(seq), comp); return Unique ? unique(sorted, eq) : sorted; } // Cache the sorted results to avoid recomputation. template struct SortedCache { static constexpr auto data = compute_sorted(Seq{}, Compare{}, Equal{}); }; // Build sorted value and ID sequences from cached sorted data template constexpr index_t get_sorted_field() { constexpr auto& data = SortedCache::data; return (Field == SortField::Values) ? data.values[I] : data.ids[I]; } template struct SortedSequences; template struct SortedSequences> { using values_type = Sequence()...>; using ids_type = Sequence()...>; }; template using sorted_sequences_t = SortedSequences< Unique, Seq, Compare, Equal, typename arithmetic_sequence_gen<0, SortedCache::data.size, 1>:: type>; using Equal = ck::math::equal; } // namespace sort_impl template struct sequence_sort { using sorted_seqs = sort_impl::sorted_sequences_t; using type = typename sorted_seqs::values_type; using sorted2unsorted_map = typename sorted_seqs::ids_type; }; template struct sequence_unique_sort { using sorted_seqs = sort_impl::sorted_sequences_t; using type = typename sorted_seqs::values_type; using sorted2unsorted_map = typename sorted_seqs::ids_type; }; template struct is_valid_sequence_map : is_same::type, typename sequence_sort>::type> { }; template struct sequence_map_inverse { template struct sequence_map_inverse_impl { static constexpr auto new_y2x = WorkingY2X::Modify(X2Y::At(Number{}), Number{}); using type = typename sequence_map_inverse_impl:: type; }; template struct sequence_map_inverse_impl { using type = WorkingY2X; }; using type = typename sequence_map_inverse_impl::type, 0, SeqMap::Size()>::type; }; template __host__ __device__ constexpr bool operator==(Sequence, Sequence) { return ((Xs == Ys) && ...); } template __host__ __device__ constexpr auto operator+(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs + Ys)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs - Ys)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs * Ys)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs / Ys)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Sequence) { static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); return Sequence<(Xs % Ys)...>{}; } template __host__ __device__ constexpr auto operator+(Sequence, Number) { return Sequence<(Xs + Y)...>{}; } template __host__ __device__ constexpr auto operator-(Sequence, Number) { return Sequence<(Xs - Y)...>{}; } template __host__ __device__ constexpr auto operator*(Sequence, Number) { return Sequence<(Xs * Y)...>{}; } template __host__ __device__ constexpr auto operator/(Sequence, Number) { return Sequence<(Xs / Y)...>{}; } template __host__ __device__ constexpr auto operator%(Sequence, Number) { return Sequence<(Xs % Y)...>{}; } template __host__ __device__ constexpr auto operator+(Number, Sequence) { return Sequence<(Y + Xs)...>{}; } template __host__ __device__ constexpr auto operator-(Number, Sequence) { return Sequence<(Y - Xs)...>{}; } template __host__ __device__ constexpr auto operator*(Number, Sequence) { return Sequence<(Y * Xs)...>{}; } template __host__ __device__ constexpr auto operator/(Number, Sequence) { return Sequence<(Y / Xs)...>{}; } template __host__ __device__ constexpr auto operator%(Number, Sequence) { return Sequence<(Y % Xs)...>{}; } template __host__ __device__ constexpr auto sequence_pop_front(Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto sequence_pop_back(Seq) { static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!"); return sequence_pop_front(Seq::Reverse()).Reverse(); } template __host__ __device__ constexpr auto merge_sequences(Seqs...) { return typename sequence_merge::type{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence) { return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence, Sequence) { static_assert(Sequence::mSize == Sequence::mSize && Sequence::mSize == Sequence::mSize, "Dim not the same"); return Sequence{}; } template __host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number) { return typename sequence_reverse_inclusive_scan::type{}; } template __host__ __device__ constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, Number) { return reverse_inclusive_scan_sequence(Seq::PopFront(), Reduce{}, Number{}) .PushBack(Number{}); } template __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number) { return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number{}).Reverse(); } template __host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence /* ids */) { return Sequence{})...>{}; } #if 1 namespace detail { template struct pick_sequence_elements_by_mask_impl { using new_work_seq = typename conditional::type; using type = typename pick_sequence_elements_by_mask_impl::type; }; template struct pick_sequence_elements_by_mask_impl, Sequence<>> { using type = WorkSeq; }; } // namespace detail template __host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask) { static_assert(Seq::Size() == Mask::Size(), "wrong!"); return typename detail::pick_sequence_elements_by_mask_impl, Seq, Mask>::type{}; } namespace detail { template struct modify_sequence_elements_by_ids_impl { using new_work_seq = decltype(WorkSeq::Modify(RemainIds::Front(), RemainValues::Front())); using type = typename modify_sequence_elements_by_ids_impl::type; }; template struct modify_sequence_elements_by_ids_impl, Sequence<>> { using type = WorkSeq; }; } // namespace detail template __host__ __device__ constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids) { static_assert(Values::Size() == Ids::Size() && Seq::Size() >= Values::Size(), "wrong!"); return typename detail::modify_sequence_elements_by_ids_impl::type{}; } #endif template __host__ __device__ constexpr index_t reduce_on_sequence(Seq, Reduce f, Number /*initial_value*/) { index_t result = Init; for(index_t i = 0; i < Seq::Size(); ++i) { result = f(result, Seq::At(i)); } return result; } // TODO: a generic any_of for any container template __host__ __device__ constexpr bool sequence_any_of(Seq, F f) { bool flag = false; for(index_t i = 0; i < Seq::Size(); ++i) { flag = flag || f(Seq::At(i)); } return flag; } // TODO: a generic all_of for any container template __host__ __device__ constexpr bool sequence_all_of(Seq, F f) { bool flag = true; for(index_t i = 0; i < Seq::Size(); ++i) { flag = flag && f(Seq::At(i)); } return flag; } template using sequence_merge_t = typename sequence_merge::type; template using uniform_sequence_gen_t = typename uniform_sequence_gen::type; } // namespace ck #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) template std::ostream& operator<<(std::ostream& os, const ck::Sequence) { using S = ck::Sequence; os << "{"; ck::static_for<0, S::Size() - ck::Number<1>{}, 1>{}( [&](auto i) { os << S::At(i).value << ", "; }); os << S::At(S::Size() - ck::Number<1>{}).value << "}"; return os; } #endif