From a6b95c393b1e258ee45201abff05fcfa9fb6f149 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sat, 18 May 2019 23:21:02 -0500 Subject: [PATCH] rework sequence --- src/include/ConstantTensorDescriptor.hip.hpp | 29 +-- src/include/Sequence.hip.hpp | 243 ++++++++---------- ...3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp | 124 +++++---- 3 files changed, 167 insertions(+), 229 deletions(-) diff --git a/src/include/ConstantTensorDescriptor.hip.hpp b/src/include/ConstantTensorDescriptor.hip.hpp index 06223f8cc8..45c779bd25 100644 --- a/src/include/ConstantTensorDescriptor.hip.hpp +++ b/src/include/ConstantTensorDescriptor.hip.hpp @@ -1,30 +1,11 @@ #pragma once #include "common.hip.hpp" -template -__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, RemainLengths) -{ - constexpr index_t previous_stride = PreviousStrides{}.Front(); - constexpr index_t current_length = RemainLengths{}.Back(); - constexpr index_t current_stride = current_length * previous_stride; - - return calculate_default_strides_impl(PreviousStrides{}.PushFront(Number{}), - RemainLengths{}.PopBack()); -} - -template -__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, Sequence) -{ - constexpr index_t previous_stride = PreviousStrides{}.Front(); - constexpr index_t current_stride = L1 * previous_stride; - - return PreviousStrides{}.PushFront(Number{}); -} - template __host__ __device__ constexpr auto calculate_default_strides(Lengths) { - return calculate_default_strides_impl(Sequence<1>{}, Lengths{}); + return reverse_inclusive_scan_sequence(Lengths{}.PopFront().PushBack(Number<1>{}), + std::multiplies{}); } // this is ugly, only for 2d @@ -57,7 +38,8 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence struct ConstantTensorDescriptor { - using Type = ConstantTensorDescriptor; + using Type = ConstantTensorDescriptor; + static constexpr index_t nDim = Lengths::GetSize(); __host__ __device__ constexpr ConstantTensorDescriptor() @@ -193,7 +175,8 @@ struct ConstantTensorDescriptor // folded strides constexpr auto fold_strides = Number{} * - reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies{}); + reverse_inclusive_scan_sequence(fold_intervals.PushBack(Number<1>{}), + std::multiplies{}); // left and right constexpr auto left = make_increasing_sequence(Number<0>{}, Number{}, Number<1>{}); diff --git a/src/include/Sequence.hip.hpp b/src/include/Sequence.hip.hpp index ae91b2fa29..a7ab687b07 100644 --- a/src/include/Sequence.hip.hpp +++ b/src/include/Sequence.hip.hpp @@ -9,7 +9,8 @@ struct Sequence static constexpr index_t mSize = sizeof...(Is); - const index_t mData[mSize] = {Is...}; + const index_t mData[mSize + 1] = { + Is..., 0}; // the last element is dummy, to prevent compiler complain on empty Sequence __host__ __device__ static constexpr index_t GetSize() { return mSize; } @@ -39,10 +40,7 @@ struct Sequence assert(false); } - __host__ __device__ constexpr auto Reverse() const - { - // not implemented - } + __host__ __device__ constexpr auto Reverse() const; __host__ __device__ constexpr index_t Front() const { return mData[0]; } @@ -73,13 +71,13 @@ struct Sequence template __host__ __device__ constexpr auto Extract(Number...) const { - return Sequence{})...>{}; + return Sequence{})...>{}; } template __host__ __device__ constexpr auto Extract(Sequence) const { - return Sequence{})...>{}; + return Sequence{})...>{}; } }; @@ -89,44 +87,110 @@ struct sequence_merge; template struct sequence_merge, Sequence> { - using Type = Sequence; + using SeqType = Sequence; }; template -struct increasing_sequence_gen +struct increasing_sequence_gen_impl { static constexpr index_t NSizeLeft = NSize / 2; - using Type = - sequence_merge::Type, - typename increasing_sequence_gen::Type>; + using SeqType = typename sequence_merge< + typename increasing_sequence_gen_impl::SeqType, + typename increasing_sequence_gen_impl::SeqType>::SeqType; }; template -struct increasing_sequence_gen +struct increasing_sequence_gen_impl { - using Type = Sequence; + using SeqType = Sequence; }; template -struct increasing_sequence_gen +struct increasing_sequence_gen_impl { - using Type = Sequence<>; + using SeqType = Sequence<>; +}; + +template +struct increasing_sequence_gen +{ + using SeqType = + typename increasing_sequence_gen_impl::SeqType; }; template __host__ __device__ constexpr auto make_increasing_sequence(Number, Number, Number) { - static_assert(IBegin <= IEnd && Increment > 0, "wrong!"); - - constexpr index_t NSize = (IEnd - IBegin) / Increment; - - return increasing_sequence_gen{}; + return typename increasing_sequence_gen::SeqType{}; } +template +struct sequence_reverse_inclusive_scan; + +template +struct sequence_reverse_inclusive_scan, Reduce> +{ + using old_scan = typename sequence_reverse_inclusive_scan, Reduce>::SeqType; + + static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front()); + + using SeqType = typename sequence_merge, old_scan>::SeqType; +}; + +template +struct sequence_reverse_inclusive_scan, Reduce> +{ + using SeqType = Sequence; +}; + +template +struct sequence_extract; + +template +struct sequence_extract> +{ + using SeqType = Sequence{})...>; +}; + +template +struct sequence_split +{ + static constexpr index_t NSize = Seq{}.GetSize(); + + using range0 = typename increasing_sequence_gen<0, I, 1>::SeqType; + using range1 = typename increasing_sequence_gen::SeqType; + + using SeqType0 = typename sequence_extract::SeqType; + using SeqType1 = typename sequence_extract::SeqType; +}; + +template +struct sequence_reverse +{ + static constexpr index_t NSize = Seq{}.GetSize(); + + using seq_split = sequence_split; + using SeqType = typename sequence_merge< + typename sequence_reverse::SeqType, + typename sequence_reverse::SeqType>::SeqType; +}; + +template +struct sequence_reverse> +{ + using SeqType = Sequence; +}; + +template +struct sequence_reverse> +{ + using SeqType = Sequence; +}; + template __host__ __device__ constexpr auto operator+(Sequence, Sequence) { @@ -179,9 +243,9 @@ __host__ __device__ constexpr auto operator+(Sequence, Number) template __host__ __device__ constexpr auto operator-(Sequence, Number) { +#if 0 // doesn't compile constexpr auto seq_x = Sequence{}; -#if 0 // doesn't compile static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) { constexpr auto I = decltype(Iter){}; static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow"); @@ -253,95 +317,12 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence) return Sequence{}; } -#if 0 -// TODO: for some reason, compiler cannot instantiate this template -template -__host__ __device__ constexpr auto sequence_pop_back(Sequence) +template +__host__ __device__ constexpr auto sequence_pop_back(Seq) { - static_assert(sizeof...(Is) > 0, "empty Sequence!"); - return Sequence{}; + static_assert(Seq{}.GetSize() > 0, "empty Sequence!"); + return sequence_pop_front(Seq{}.Reverse()).Reverse(); } -#else -// TODO: delete these very ugly mess -template -__host__ __device__ constexpr auto sequence_pop_back(Sequence) -{ - return Sequence{}; -} - -template -__host__ __device__ constexpr auto sequence_pop_back(Sequence) -{ - return Sequence{}; -} - -template -__host__ __device__ constexpr auto sequence_pop_back(Sequence) -{ - return Sequence{}; -} - -template -__host__ __device__ constexpr auto sequence_pop_back(Sequence) -{ - return Sequence{}; -} - -template -__host__ __device__ constexpr auto sequence_pop_back(Sequence) -{ - return Sequence{}; -} - -template -__host__ __device__ constexpr auto sequence_pop_back(Sequence) -{ - return Sequence{}; -} - -template -__host__ __device__ constexpr auto sequence_pop_back(Sequence) -{ - return Sequence{}; -} - -template -__host__ __device__ constexpr auto sequence_pop_back(Sequence) -{ - return Sequence{}; -} - -template -__host__ __device__ constexpr auto - sequence_pop_back(Sequence) -{ - return Sequence{}; -} -#endif template __host__ __device__ constexpr auto transform_sequences(F f, Sequence) @@ -399,38 +380,20 @@ __host__ __device__ constexpr index_t return Reduce{}(a, I); } -template -struct scan_sequence_impl +template +__host__ __device__ constexpr auto Sequence::Reverse() const { - template - __host__ __device__ constexpr auto operator()(ScanedSeq, RemainSeq, Reduce) const - { - static_assert(RemainSeq{}.GetSize() == NRemain, - "wrong! RemainSeq and NRemain not consistent!"); - - constexpr index_t a = Reduce{}(ScanedSeq{}.Back(), RemainSeq{}.Front()); - constexpr auto scaned_seq = ScanedSeq{}.PushBack(Number{}); - - static_if<(NRemain > 1)>{}([&](auto fwd) { - return scan_sequence_impl{}( - scaned_seq, RemainSeq{}.PopFront(), fwd(Reduce{})); - }).else_([&](auto fwd) { return fwd(scaned_seq); }); - } -}; - -template -__host__ __device__ constexpr auto scan_sequence(Seq, Reduce) -{ - constexpr auto scaned_seq = Sequence{}; - constexpr auto remain_seq = Seq{}.PopFront(); - - constexpr index_t remain_size = Seq::GetSize() - 1; - - return scan_sequence_impl{}(scaned_seq, remain_seq, Reduce{}); + return typename sequence_reverse>::SeqType{}; } template -__host__ __device__ constexpr auto reverse_scan_sequence(Seq, Reduce) +__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce) { - return scan_seqeunce(Seq{}.Reverse(), Reduce{}).Reverse(); + return typename sequence_reverse_inclusive_scan::SeqType{}; +} + +template +__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce) +{ + return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}).Reverse(); } diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp index 1cf033bef5..c0a67837b6 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp @@ -80,29 +80,26 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, "wrong! cannot evenly divide work for workgroup "); - constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; - constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; - constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock; + constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock); + constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock); + constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock); + constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock); - const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); - index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); - const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork); - itmp -= h_block_work_id * (WBlockWork * NBlockWork); - const index_t w_block_work_id = itmp / NBlockWork; - const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork; + constexpr auto block_work_desc = make_ConstantTensorDescriptor( + Sequence{}); - const index_t k_block_data_begin = k_block_work_id * KPerBlock; - const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; - const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; - const index_t n_block_data_begin = n_block_work_id * NPerBlock; + const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id()); + + const index_t k_block_data_begin = block_work_multi_id[0] * KPerBlock; + const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock; + const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock; + const index_t n_block_data_begin = block_work_multi_id[3] * NPerBlock; const index_t hi_block_data_begin = ho_block_data_begin; const index_t wi_block_data_begin = wo_block_data_begin; // global tensor view - constexpr auto wei_c_k_global_desc = - make_ConstantTensorDescriptor(Sequence{}, Sequence{}); + constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3); // LDS tensor view // be careful of alignment @@ -360,13 +357,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; - static_if{}([&](auto f_dummy) { // f_dummy do nothing but - // perfect forwarding. - // Using this trick to - // make this lambda a generic lambda, so it won't be compiled until - // instantiated + static_if{}([&](auto fwd) { + // fwd do nothing but perfect forwarding. + // Using this trick to make this lambda a generic lambda, so it won't be compiled until + // being instantiated here static_assert( - (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0), + (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0), "wrong!"); // output is a 10d tensor @@ -374,38 +370,33 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn constexpr index_t N1 = NPerBlock / N2; constexpr index_t W2 = - (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC); + (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC); constexpr index_t W1 = WoPerBlock / W2; constexpr index_t K2 = GemmMPerThreadSubC; constexpr index_t K1 = KPerBlock / KPerThread; - constexpr auto out_10d_global_desc = - make_ConstantTensorDescriptor(Sequence{}); + constexpr auto out_10d_global_desc = fwd(out_k_h_w_n_global_desc) + .Fold(I3, Number{}, Number{}) + .Fold(I2, Number{}, Number{}) + .Fold(I0, Number{}, Number{}); - constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); + constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc) + .Fold(I3, Number<1>{}, Number{}) + .Fold(I2, Number{}, Number<1>{}) + .Fold(I0, Number<1>{}, Number{}); #if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "out_k_h_w_n_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + "a: out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc"); - print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, - "out_k_h_w_n_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); - } + print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, + "a: out_k_h_w_n_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc"); + } #endif threadwise_tensor_slice_copy( @@ -419,8 +410,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn n_block_data_begin + n_thread_data_begin), out_10d_thread_desc.GetLengths(), Number{}); - }).else_([&](auto f_dummy) { - static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && + }).else_([&](auto fwd) { + static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && GemmNPerThreadSubC % NPerThread == 0, "wrong!"); @@ -429,33 +420,34 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; - constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3); + constexpr index_t W1 = WoPerBlock / fwd(W2 * W3); constexpr index_t K2 = GemmMPerThreadSubC; constexpr index_t K1 = KPerBlock / KPerThread; - constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor( - Sequence{}); + constexpr auto out_10d_global_desc = + fwd(out_k_h_w_n_global_desc) + .Fold(I3, Number{}) + .Fold(I2, Number{}, Number{}, Number{}) + .Fold(I0, Number{}, Number{}); - constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); + constexpr auto out_10d_thread_desc = + fwd(out_k_h_w_n_thread_desc) + .Fold(I3, Number{}) + .Fold(I2, Number{}, Number<1>{}, Number{}) + .Fold(I0, Number<1>{}, Number{}); #if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "out_k_h_w_n_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, + "b: out_k_h_w_n_thread_desc"); + print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc"); - print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, - "out_k_h_w_n_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); - - for(index_t i = 0; i < 64; ++i) - { - printf("out %f, ", p_out_thread[i]); - } - } + print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, + "b: out_k_h_w_n_global_desc"); + print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc"); + } #endif threadwise_tensor_slice_copy(