diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp index 3636db9fa6..042cb8f91d 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp @@ -440,7 +440,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw wo_block_data_begin + wo_thread_data_begin), make_zero_array(), out_10d_thread_desc.GetLengths().ReorderGivenNew2Old(map_out_global2thread), - arithmetic_sequence_gen<0, 10, 1>::SeqType{}, + arithmetic_sequence_gen<0, 10, 1>::type{}, Number<1>{}); #endif }); diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp index 5a7267d75c..ff5246435b 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp @@ -491,7 +491,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer wo_block_data_begin + wo_thread_data_begin), make_zero_array(), out_10d_thread_desc.GetLengths().ReorderGivenNew2Old(map_out_global2thread), - arithmetic_sequence_gen<0, 10, 1>::SeqType{}, + arithmetic_sequence_gen<0, 10, 1>::type{}, Number<1>{}); #endif }); diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp index 70737ebc6c..56b2f5f0b9 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp @@ -367,7 +367,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw p_out_thread_on_global, {0, 0, 0, 0, 0, 0, 0, 0}, out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(), - arithmetic_sequence_gen<0, 8, 1>::SeqType{}, + arithmetic_sequence_gen<0, 8, 1>::type{}, Number<1>{}); } } diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp index ca8412355e..bb9b6cbd07 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp @@ -394,7 +394,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer p_out_thread_on_global, {0, 0, 0, 0, 0, 0, 0, 0}, out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(), - arithmetic_sequence_gen<0, 8, 1>::SeqType{}, + arithmetic_sequence_gen<0, 8, 1>::type{}, Number<1>{}); } } diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp index 461db757bd..915193dc40 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp @@ -344,7 +344,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw p_out_thread_on_global, {0, 0, 0, 0, 0, 0, 0, 0}, out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(), - arithmetic_sequence_gen<0, 8, 1>::SeqType{}, + arithmetic_sequence_gen<0, 8, 1>::type{}, Number<1>{}); } } diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp index 2b41e2640e..73cec0bb10 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -398,7 +398,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer p_out_thread_on_global, {0, 0, 0, 0, 0, 0, 0, 0}, out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(), - arithmetic_sequence_gen<0, 8, 1>::SeqType{}, + arithmetic_sequence_gen<0, 8, 1>::type{}, Number<1>{}); } } diff --git a/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp b/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp index a8295d6624..3ed9f3a2b8 100644 --- a/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp +++ b/composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp @@ -305,8 +305,9 @@ struct ConstantTensorDescriptor { using leaf_tensor = ConstantTensorDescriptor; - return ConstantTensorDescriptor{}; + return ConstantTensorDescriptor{}; } template @@ -347,7 +348,7 @@ struct ConstantTensorDescriptor // folded lengths constexpr auto fold_lengths = - Sequence{}.Append(fold_intervals); + Sequence{}.PushBack(fold_intervals); // folded strides constexpr auto fold_strides = @@ -356,14 +357,14 @@ struct ConstantTensorDescriptor fold_intervals.PushBack(Number<1>{}), math::multiplies{}, Number<1>{}); // left and right - constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::SeqType{}; + constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::type{}; constexpr auto right = - typename arithmetic_sequence_gen::SeqType{}; + typename arithmetic_sequence_gen::type{}; constexpr auto new_lengths = - GetLengths().Extract(left).Append(fold_lengths).Append(GetLengths().Extract(right)); + GetLengths().Extract(left).PushBack(fold_lengths).PushBack(GetLengths().Extract(right)); constexpr auto new_strides = - GetStrides().Extract(left).Append(fold_strides).Append(GetStrides().Extract(right)); + GetStrides().Extract(left).PushBack(fold_strides).PushBack(GetStrides().Extract(right)); return ConstantTensorDescriptor{}; } @@ -377,11 +378,11 @@ struct ConstantTensorDescriptor "wrong! should have FirstUnfoldDim <= LastUnfoldDim!"); // left and right - constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::SeqType{}; + constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::type{}; constexpr auto middle = - typename arithmetic_sequence_gen::SeqType{}; + typename arithmetic_sequence_gen::type{}; constexpr auto right = - typename arithmetic_sequence_gen::SeqType{}; + typename arithmetic_sequence_gen::type{}; // dimensions to be unfolded need to be continuous static_assert(Type::Extract(middle).AreDimensionsContinuous(), "wrong! not unfoldable"); @@ -396,12 +397,12 @@ struct ConstantTensorDescriptor constexpr auto new_lengths = GetLengths() .Extract(left) .PushBack(Number{}) - .Append(GetLengths().Extract(right)); + .PushBack(GetLengths().Extract(right)); constexpr auto new_strides = GetStrides() .Extract(left) .PushBack(Number{}) - .Append(GetStrides().Extract(right)); + .PushBack(GetStrides().Extract(right)); return ConstantTensorDescriptor{}; } diff --git a/composable_kernel/include/utility/Array.hpp b/composable_kernel/include/utility/Array.hpp index f33fa516e2..afe5b392f6 100644 --- a/composable_kernel/include/utility/Array.hpp +++ b/composable_kernel/include/utility/Array.hpp @@ -87,7 +87,7 @@ __host__ __device__ constexpr auto sequence2array(Sequence) template __host__ __device__ constexpr auto make_zero_array() { - constexpr auto zero_sequence = typename uniform_sequence_gen::SeqType{}; + constexpr auto zero_sequence = typename uniform_sequence_gen::type{}; constexpr auto zero_array = sequence2array(zero_sequence); return zero_array; } diff --git a/composable_kernel/include/utility/Sequence.hpp b/composable_kernel/include/utility/Sequence.hpp index 5c566503a6..cf37b95ec2 100644 --- a/composable_kernel/include/utility/Sequence.hpp +++ b/composable_kernel/include/utility/Sequence.hpp @@ -29,12 +29,9 @@ struct Sequence } template - __host__ __device__ constexpr index_t operator[](Number) const + __host__ __device__ constexpr auto operator[](Number) const { - static_assert(I < mSize, "wrong! I too large"); - - const index_t mData[mSize + 1] = {Is..., 0}; - return mData[I]; + return Number{})>{}; } // make sure I is constepxr @@ -69,24 +66,30 @@ struct Sequence return mData[mSize - 1]; } - template - __host__ __device__ static constexpr auto PushFront(Number) - { - return Sequence{}; - } - - template - __host__ __device__ static constexpr auto PushBack(Number) - { - return Sequence{}; - } - __host__ __device__ static constexpr auto PopFront(); __host__ __device__ static constexpr auto PopBack(); template - __host__ __device__ static constexpr auto Append(Sequence) + __host__ __device__ static constexpr auto PushFront(Sequence) + { + return Sequence{}; + } + + template + __host__ __device__ static constexpr auto PushFront(Number...) + { + return Sequence{}; + } + + template + __host__ __device__ static constexpr auto PushBack(Sequence) + { + return Sequence{}; + } + + template + __host__ __device__ static constexpr auto PushBack(Number...) { return Sequence{}; } @@ -105,6 +108,12 @@ struct Sequence template __host__ __device__ static constexpr auto Modify(Number, Number); + + template + __host__ __device__ static constexpr auto Transform(F f) + { + return Sequence{}; + } }; // merge sequence @@ -114,7 +123,7 @@ struct sequence_merge; template struct sequence_merge, Sequence> { - using SeqType = Sequence; + using type = Sequence; }; // arithmetic sqeuence @@ -123,40 +132,29 @@ struct arithmetic_sequence_gen_impl { static constexpr index_t NSizeLeft = NSize / 2; - using SeqType = typename sequence_merge< - typename arithmetic_sequence_gen_impl::SeqType, + using type = typename sequence_merge< + typename arithmetic_sequence_gen_impl::type, typename arithmetic_sequence_gen_impl::SeqType>::SeqType; + Increment>::type>::type; }; template struct arithmetic_sequence_gen_impl { - using SeqType = Sequence; + using type = Sequence; }; template struct arithmetic_sequence_gen_impl { - using SeqType = Sequence<>; + using type = Sequence<>; }; template struct arithmetic_sequence_gen { - using SeqType = - typename arithmetic_sequence_gen_impl::SeqType; -}; - -// transform sequence -template -struct sequence_transform; - -template -struct sequence_transform> -{ - using SeqType = Sequence; + using type = typename arithmetic_sequence_gen_impl::type; }; // uniform sequence @@ -168,9 +166,8 @@ struct uniform_sequence_gen __host__ __device__ constexpr index_t operator()(index_t) const { return I; } }; - using SeqType = typename sequence_transform< - return_constant, - typename arithmetic_sequence_gen<0, NSize, 1>::SeqType>::SeqType; + using type = decltype( + typename arithmetic_sequence_gen<0, NSize, 1>::type{}.Transform(return_constant{})); }; // reverse inclusive scan (with init) sequence @@ -180,34 +177,23 @@ struct sequence_reverse_inclusive_scan; template struct sequence_reverse_inclusive_scan, Reduce, Init> { - using old_scan = - typename sequence_reverse_inclusive_scan, Reduce, Init>::SeqType; + using old_scan = typename sequence_reverse_inclusive_scan, Reduce, Init>::type; static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front()); - using SeqType = typename sequence_merge, old_scan>::SeqType; + using type = typename sequence_merge, old_scan>::type; }; template struct sequence_reverse_inclusive_scan, Reduce, Init> { - using SeqType = Sequence; + using type = Sequence; }; template struct sequence_reverse_inclusive_scan, Reduce, Init> { - using SeqType = Sequence<>; -}; - -// extract sequence -template -struct sequence_extract; - -template -struct sequence_extract> -{ - using SeqType = Sequence{})...>; + using type = Sequence<>; }; // split sequence @@ -216,11 +202,11 @@ struct sequence_split { static constexpr index_t NSize = Seq{}.GetSize(); - using range0 = typename arithmetic_sequence_gen<0, I, 1>::SeqType; - using range1 = typename arithmetic_sequence_gen::SeqType; + using range0 = typename arithmetic_sequence_gen<0, I, 1>::type; + using range1 = typename arithmetic_sequence_gen::type; - using SeqType0 = typename sequence_extract::SeqType; - using SeqType1 = typename sequence_extract::SeqType; + using SeqType0 = decltype(Seq::Extract(range0{})); + using SeqType1 = decltype(Seq::Extract(range1{})); }; // reverse sequence @@ -230,31 +216,31 @@ struct sequence_reverse static constexpr index_t NSize = Seq{}.GetSize(); using seq_split = sequence_split; - using SeqType = typename sequence_merge< - typename sequence_reverse::SeqType, - typename sequence_reverse::SeqType>::SeqType; + using type = typename sequence_merge< + typename sequence_reverse::type, + typename sequence_reverse::type>::type; }; template struct sequence_reverse> { - using SeqType = Sequence; + using type = Sequence; }; template struct sequence_reverse> { - using SeqType = Sequence; + using type = Sequence; }; template struct is_valid_sequence_map { - static constexpr bool value = true; + static constexpr integral_constant value = integral_constant{}; // TODO: add proper check for is_valid, something like: // static constexpr bool value = - // is_same::SeqType, + // is_same::type, // typename sequence_sort::SortedSeqType>{}; }; @@ -401,7 +387,7 @@ transform_sequences(F f, Sequence, Sequence, Sequence) template __host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number) { - return typename sequence_reverse_inclusive_scan::SeqType{}; + return typename sequence_reverse_inclusive_scan::type{}; } template @@ -425,7 +411,7 @@ __host__ __device__ constexpr auto Sequence::PopBack() template __host__ __device__ constexpr auto Sequence::Reverse() { - return typename sequence_reverse>::SeqType{}; + return typename sequence_reverse>::type{}; } template @@ -438,7 +424,7 @@ __host__ __device__ constexpr auto Sequence::Modify(Number, Number) constexpr auto seq_left = typename seq_split::SeqType0{}; constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront(); - return seq_left.PushBack(Number{}).Append(seq_right); + return seq_left.PushBack(Number{}).PushBack(seq_right); } template diff --git a/composable_kernel/include/utility/functional2.hpp b/composable_kernel/include/utility/functional2.hpp index 3820056593..c49341b666 100644 --- a/composable_kernel/include/utility/functional2.hpp +++ b/composable_kernel/include/utility/functional2.hpp @@ -31,7 +31,7 @@ struct static_for static_assert((NEnd - NBegin) % Increment == 0, "Wrong! should satisfy (NEnd - NBegin) % Increment == 0"); - static_for_impl::SeqType>{}(f); + static_for_impl::type>{}(f); } }; diff --git a/composable_kernel/include/utility/utility.hpp b/composable_kernel/include/utility/utility.hpp index c744e27ddd..f52f163a7d 100644 --- a/composable_kernel/include/utility/utility.hpp +++ b/composable_kernel/include/utility/utility.hpp @@ -59,9 +59,9 @@ __host__ __device__ constexpr T integer_divide_ceil(T a, T b) } template -__host__ __device__ constexpr T max(T x, T y) +__host__ __device__ constexpr T max(T x) { - return x > y ? x : y; + return x; } template @@ -77,9 +77,9 @@ __host__ __device__ constexpr T max(T x, Ts... xs) } template -__host__ __device__ constexpr T min(T x, T y) +__host__ __device__ constexpr T min(T x) { - return x < y ? x : y; + return x; } template