diff --git a/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp b/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp index e353b060c6..e59d8e9a67 100644 --- a/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp +++ b/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp @@ -419,6 +419,13 @@ struct ConstantTensorDescriptor return ConstantTensorDescriptor{}; } + + template + __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New) + { + return ConstantTensorDescriptor{}; + } }; template diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp index b3e659970d..d7812f8680 100644 --- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp @@ -74,8 +74,8 @@ __device__ void threadwise_generic_tensor_slice_copy_v1( constexpr auto data_multi_id_in_access_order = access_multi_id.Modify(Number{}, Number{}); - constexpr auto data_multi_id = reorder_array_given_old2new( - sequence2array(data_multi_id_in_access_order), DimAccessOrder{}); + constexpr auto data_multi_id = + data_multi_id_in_access_order.ReorderGivenOld2New(DimAccessOrder{}); const index_t src_index = SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id); diff --git a/composable_kernel/include/utility/Sequence.hpp b/composable_kernel/include/utility/Sequence.hpp index 44cfd669db..1d8467afb0 100644 --- a/composable_kernel/include/utility/Sequence.hpp +++ b/composable_kernel/include/utility/Sequence.hpp @@ -6,12 +6,27 @@ namespace ck { +template +struct Sequence; + +template +struct sequence_split; + template -struct is_valid_sequence_map; +struct sequence_reverse; template struct sequence_map_inverse; +template +struct is_valid_sequence_map; + +template +__host__ __device__ constexpr auto sequence_pop_front(Sequence); + +template +__host__ __device__ constexpr auto sequence_pop_back(Seq); + template struct Sequence { @@ -71,7 +86,10 @@ struct Sequence return ReorderGivenNew2Old(typename sequence_map_inverse::type{}); } - __host__ __device__ static constexpr auto Reverse(); + __host__ __device__ static constexpr auto Reverse() + { + return typename sequence_reverse::type{}; + } __host__ __device__ static constexpr auto Front() { @@ -85,9 +103,9 @@ struct Sequence return Get(Number{}); } - __host__ __device__ static constexpr auto PopFront(); + __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); } - __host__ __device__ static constexpr auto PopBack(); + __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); } template __host__ __device__ static constexpr auto PushFront(Sequence) @@ -126,7 +144,16 @@ struct Sequence } template - __host__ __device__ static constexpr auto Modify(Number, Number); + __host__ __device__ static constexpr auto Modify(Number, Number) + { + static_assert(I < GetSize(), "wrong!"); + + using seq_split = sequence_split; + constexpr auto seq_left = typename seq_split::SeqType0{}; + constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront(); + + return seq_left.PushBack(Number{}).PushBack(seq_right); + } template __host__ __device__ static constexpr auto Transform(F f) @@ -283,7 +310,8 @@ template struct sequence_map_inverse_impl { private: - static constexpr auto new_y2x = WorkingY2X::Modify(X2Y{}[XBegin], XBegin); + static constexpr auto new_y2x = + WorkingY2X::Modify(X2Y::Get(Number{}), Number{}); public: using type = @@ -417,8 +445,8 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence) template __host__ __device__ constexpr auto sequence_pop_back(Seq) { - static_assert(Seq{}.GetSize() > 0, "wrong! cannot pop an empty Sequence!"); - return sequence_pop_front(Seq{}.Reverse()).Reverse(); + static_assert(Seq::GetSize() > 0, "wrong! cannot pop an empty Sequence!"); + return sequence_pop_front(Seq::Reverse()).Reverse(); } template @@ -458,37 +486,6 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number{}).Reverse(); } -template -__host__ __device__ constexpr auto Sequence::PopFront() -{ - return sequence_pop_front(Type{}); -} - -template -__host__ __device__ constexpr auto Sequence::PopBack() -{ - return sequence_pop_back(Type{}); -} - -template -__host__ __device__ constexpr auto Sequence::Reverse() -{ - return typename sequence_reverse>::type{}; -} - -template -template -__host__ __device__ constexpr auto Sequence::Modify(Number, Number) -{ - static_assert(I < GetSize(), "wrong!"); - - using seq_split = sequence_split; - constexpr auto seq_left = typename seq_split::SeqType0{}; - constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront(); - - return seq_left.PushBack(Number{}).PushBack(seq_right); -} - template __host__ __device__ void print_Sequence(const char* s, Sequence) {