From 709f13a6d7c250f4b0777f0f36fb17c5fc08d01a Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 4 Jun 2019 20:00:48 -0500 Subject: [PATCH] use more constexpr --- driver/driver.hip.cpp | 18 +- src/include/Array.hip.hpp | 55 +++++- .../ConstantMergedTensorDescriptor.hip.hpp | 19 +-- src/include/Sequence.hip.hpp | 54 ++++-- .../blockwise_generic_tensor_slice_op.hip.hpp | 13 ++ src/include/common.hip.hpp | 1 + src/include/functional.hip.hpp | 54 +----- src/include/functional2.hip.hpp | 156 ++++++++---------- src/include/functional3.hip.hpp | 109 ++++++++++++ ...4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp | 2 +- ...threadwise_generic_tensor_slice_op.hip.hpp | 12 +- 11 files changed, 310 insertions(+), 183 deletions(-) create mode 100644 src/include/functional3.hip.hpp diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index bb228d31bf..40cd4fdd3f 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -443,7 +443,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 0 +#elif 1 // 3x3 filter, 28x28 image constexpr index_t N = 128; constexpr index_t C = 256; @@ -455,7 +455,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 1 +#elif 0 // 1x1 filter, 28x28 image constexpr index_t N = 128; constexpr index_t C = 512; @@ -549,12 +549,24 @@ int main(int argc, char* argv[]) constexpr index_t Y = 1; constexpr index_t X = 1; + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; +#elif 0 + // 1x1 filter, 7x7 image + constexpr index_t N = 128; + constexpr index_t C = 512; + constexpr index_t HI = 7; + constexpr index_t WI = 7; + constexpr index_t K = 2048; + constexpr index_t Y = 1; + constexpr index_t X = 1; + constexpr index_t HPad = 0; constexpr index_t WPad = 0; #elif 0 // 1x1 filter, 73x73 image constexpr index_t N = 128; - constexpr index_t C = 64; + constexpr index_t C = 512; constexpr index_t HI = 73; constexpr index_t WI = 73; constexpr index_t K = 128; 
diff --git a/src/include/Array.hip.hpp b/src/include/Array.hip.hpp index 9d3d385738..4e1162a8a3 100644 --- a/src/include/Array.hip.hpp +++ b/src/include/Array.hip.hpp @@ -1,6 +1,6 @@ #pragma once #include "Sequence.hip.hpp" -#include "functional.hip.hpp" +#include "functional2.hip.hpp" template struct Array @@ -25,14 +25,17 @@ struct Array template __host__ __device__ constexpr TData Get(Number) const { + static_assert(I < NSize, "wrong!"); + return mData[I]; } template - __host__ __device__ constexpr bool Set(Number, TData x) + __host__ __device__ constexpr void Set(Number, TData x) { + static_assert(I < NSize, "wrong!"); + mData[I] = x; - return true; // for constexpr } __host__ __device__ constexpr auto PushBack(TData x) const @@ -59,6 +62,7 @@ __host__ __device__ constexpr auto sequence2array(Sequence) template __host__ __device__ constexpr auto make_zero_array() { +#if 0 Array a; static_for<0, NSize, 1>{}([&](auto I) { @@ -67,6 +71,11 @@ __host__ __device__ constexpr auto make_zero_array() }); return a; +#else + constexpr auto zero_sequence = typename uniform_sequence_gen::SeqType{}; + constexpr auto zero_array = sequence2array(zero_sequence); + return zero_array; +#endif } template @@ -85,6 +94,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array __host__ __device__ constexpr auto reorder_array_given_old2new(const Array& old_array, Sequence old2new) @@ -100,6 +110,45 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array +struct reorder_array_given_old2new_impl +{ + const Array& old_array_ref; + Array& new_array_ref; + + __host__ + __device__ constexpr reorder_array_given_old2new_impl(const Array& old_array, + Array& new_array) + : old_array_ref(old_array), new_array_ref(new_array) + { + } + + template + __host__ __device__ constexpr void operator()(Number) const + { + TData old_data = old_array_ref.Get(Number{}); + + constexpr index_t INewDim = MapOld2New::Get(Number{}); + + new_array_ref.Set(Number{}, 
old_data); + } +}; + +template +__host__ __device__ constexpr auto reorder_array_given_old2new(const Array& old_array, + Sequence old2new) +{ + Array new_array; + + static_assert(NSize == sizeof...(IRs), "NSize not consistent"); + + static_for<0, NSize, 1>{}( + reorder_array_given_old2new_impl>(old_array, new_array)); + + return new_array; +} +#endif template __host__ __device__ constexpr auto extract_array(const Array& old_array, ExtractSeq) diff --git a/src/include/ConstantMergedTensorDescriptor.hip.hpp b/src/include/ConstantMergedTensorDescriptor.hip.hpp index b595a2f0a4..8d5ceb3825 100644 --- a/src/include/ConstantMergedTensorDescriptor.hip.hpp +++ b/src/include/ConstantMergedTensorDescriptor.hip.hpp @@ -115,15 +115,13 @@ struct ConstantMergedTensorDescriptor } template - constexpr __host__ __device__ bool operator()(Number) const + __host__ __device__ constexpr void operator()(Number) const { constexpr index_t idim_original = OriginalDimsPartial::Get(Number{}); index_t itmp = original_multi_id_partial_ref.Get(Number{}); original_multi_id_ref.Set(Number{}, itmp); - - return true; } }; @@ -139,7 +137,7 @@ struct ConstantMergedTensorDescriptor } template - constexpr __host__ __device__ bool operator()(Number) const + __host__ __device__ constexpr void operator()(Number) const { constexpr auto original_dims_partial = std::get(std::tuple{}); @@ -152,11 +150,10 @@ struct ConstantMergedTensorDescriptor static_for<0, original_dims_partial.GetSize(), 1>{}( GetOriginalMultiIndexFromMultiIndex_impl1( original_multi_id_partial, original_multi_id_ref)); - - return true; } }; + // return type is Array<...> __host__ __device__ static constexpr auto GetOriginalMultiIndexFromMultiIndex(Array multi_id) { @@ -179,16 +176,6 @@ struct ConstantMergedTensorDescriptor } #endif -#if 0 - // return type is Sequence<...> - template - __host__ __device__ static constexpr auto GetOriginalMultiIndexFromMultiIndex(Sequence) - { - // not implemented - return Sequence<>{}; - } -#endif - 
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Array multi_id) { diff --git a/src/include/Sequence.hip.hpp b/src/include/Sequence.hip.hpp index b14b88d4d5..e5edec4ba6 100644 --- a/src/include/Sequence.hip.hpp +++ b/src/include/Sequence.hip.hpp @@ -37,10 +37,11 @@ struct Sequence template __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/) { +#if 0 static_assert(is_same::SortedSeqType, arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value, "wrong! invalid old2new map"); - +#endif constexpr auto map_new2old = typename sequence_map_inverse::SeqMapType{}; return ReorderGivenNew2Old(map_new2old); @@ -99,6 +100,7 @@ struct Sequence __host__ __device__ static constexpr auto Modify(Number, Number); }; +// merge sequence template struct sequence_merge; @@ -108,6 +110,7 @@ struct sequence_merge, Sequence> using SeqType = Sequence; }; +// arithmetic sequence template struct arithmetic_sequence_gen_impl { @@ -139,7 +142,31 @@ struct arithmetic_sequence_gen typename arithmetic_sequence_gen_impl::SeqType; }; -// reverse scan with init +// transform sequence +template +struct sequence_transform; + +template +struct sequence_transform> +{ + using SeqType = Sequence; +}; + +// uniform sequence +template +struct uniform_sequence_gen +{ + struct return_constant + { + __host__ __device__ constexpr index_t operator()(index_t) const { return I; } + }; + + using SeqType = typename sequence_transform< + return_constant, + typename arithmetic_sequence_gen<0, NSize, 1>::SeqType>::SeqType; +}; + +// reverse inclusive scan (with init) sequence template struct sequence_reverse_inclusive_scan; @@ -166,22 +193,7 @@ struct sequence_reverse_inclusive_scan, Reduce, Init> using SeqType = Sequence<>; }; -#if 0 -// reverse scan with token -template -struct sequence_reverse_inclusive_token_scan; - -template -struct sequence_reverse_inclusive_token_scan, F, Token> -{ - using old_scan = typename sequence_reverse_inclusive_token_scan, F,
Token>::SeqType; - - static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front()); - - using SeqType = typename sequence_merge, old_scan>::SeqType; -}; -#endif - +// extract sequence template struct sequence_extract; @@ -191,6 +203,7 @@ struct sequence_extract> using SeqType = Sequence{})...>; }; +// split sequence template struct sequence_split { @@ -203,6 +216,7 @@ struct sequence_split using SeqType1 = typename sequence_extract::SeqType; }; +// reverse sequence template struct sequence_reverse { @@ -308,8 +322,10 @@ __host__ __device__ constexpr auto operator-(Sequence seq_x, Sequence{}( [&](auto I) { static_assert(seq_x.Get(I) >= seq_y.Get(I), "wrong! going to undeflow"); }); +#endif return Sequence<(Xs - Ys)...>{}; } @@ -388,10 +404,12 @@ __host__ __device__ constexpr auto operator-(Number, Sequence) { constexpr auto seq_x = Sequence{}; +#if 0 static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) { constexpr auto I = decltype(Iter){}; static_assert(seq_x.Get(I) <= Y, "wrong! going to underflow"); }); +#endif return Sequence<(Y - Xs)...>{}; } diff --git a/src/include/blockwise_generic_tensor_slice_op.hip.hpp b/src/include/blockwise_generic_tensor_slice_op.hip.hpp index d080c362e6..9402164178 100644 --- a/src/include/blockwise_generic_tensor_slice_op.hip.hpp +++ b/src/include/blockwise_generic_tensor_slice_op.hip.hpp @@ -256,6 +256,7 @@ struct BlockwiseGenericTensorSliceCopy_v1 make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths); static_ford{}([&](auto repeat_multi_id_) { +#if 0 constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){}); const auto clipboard_data_multi_id_begin = @@ -269,6 +270,18 @@ struct BlockwiseGenericTensorSliceCopy_v1 const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex( dst_data_multi_id_begin); // cannot not constexpr, why? 
+#else + constexpr auto clipboard_data_multi_id_begin = + repeat_multi_id_ * thread_sub_tensor_lengths; + + constexpr auto dst_data_multi_id_begin = repeat_multi_id_ * data_per_cluster_per_dims; + + constexpr index_t clipboard_offset = + thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin); + + constexpr index_t dst_offset = + DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin); +#endif threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc, p_clipboard + clipboard_offset, diff --git a/src/include/common.hip.hpp b/src/include/common.hip.hpp index 577e44ac97..bc8df3bc5a 100644 --- a/src/include/common.hip.hpp +++ b/src/include/common.hip.hpp @@ -5,6 +5,7 @@ #include "Array.hip.hpp" #include "functional.hip.hpp" #include "functional2.hip.hpp" +#include "functional3.hip.hpp" #if USE_AMD_INLINE_ASM #include "amd_inline_asm.hip.hpp" diff --git a/src/include/functional.hip.hpp b/src/include/functional.hip.hpp index f287e5c08e..ab811d292c 100644 --- a/src/include/functional.hip.hpp +++ b/src/include/functional.hip.hpp @@ -1,5 +1,6 @@ #pragma once #include "integral_constant.hip.hpp" +#include "Sequence.hip.hpp" struct forwarder { @@ -10,6 +11,14 @@ struct forwarder } }; +struct swallow +{ + template + __host__ __device__ constexpr swallow(Ts&&... ts) + { + } +}; + #if 0 template __host__ __device__ constexpr auto unpacker(F f) @@ -72,51 +81,6 @@ struct static_if return Type{}; } }; -template -struct static_for_impl -{ - template - constexpr __host__ __device__ void operator()(F f) const - { - static_assert(Remaining % Increment == 0, "wrong! 
Remaining % Increment != 0"); - static_assert(Increment <= Remaining, "will go out-of-range"); - - f(Number{}); - static_for_impl{}(f); - } -}; - -template -struct static_for_impl -{ - template - constexpr __host__ __device__ void operator()(F) const - { - // no work left, just return - return; - } -}; - -// F signature: F(Number) -template -struct static_for -{ - template - constexpr __host__ __device__ void operator()(F f) const - { - static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd"); - - static_assert((NEnd - NBegin) % Increment == 0, - "Wrong! should satisfy (NEnd - NBegin) % Increment == 0"); - -#if 0 - static_if<(NBegin < NEnd)>{}( - [&](auto fwd) { static_for_impl{}(f); }); -#else - static_for_impl{}(f); -#endif - } -}; template struct static_const_reduce_n diff --git a/src/include/functional2.hip.hpp b/src/include/functional2.hip.hpp index b71c56e719..e307f31f60 100644 --- a/src/include/functional2.hip.hpp +++ b/src/include/functional2.hip.hpp @@ -1,106 +1,80 @@ #pragma once +#include "functional.hip.hpp" #include "Sequence.hip.hpp" -// RemainLengths: Sequence<...> -template -struct static_ford_impl +#if 0 +template +struct static_for_impl { - // F signature: F(Sequence<...> multi_id) - // CurrentMultiIndex: Sequence<...> - template - __host__ __device__ void operator()(F f, CurrentMultiIndex) const - { - static_assert(RemainLengths::GetSize() > 0, "wrong! 
should not get here"); - - static_for<0, RemainLengths::Front(), 1>{}([=](auto I) { - static_ford_impl{}(f, - CurrentMultiIndex::PushBack(I)); - }); - } -}; - -template <> -struct static_ford_impl> -{ - // F signature: F(Sequence<...> multi_id) - // CurrentMultiIndex: Sequence<...> - template - __host__ __device__ void operator()(F f, CurrentMultiIndex) const - { - f(CurrentMultiIndex{}); - } -}; - -// Lengths is Sequence<...> -template -struct static_ford -{ - // F signature: F(Sequence<...> multi_id) template - __host__ __device__ void operator()(F f) const + constexpr __host__ __device__ void operator()(F f) const { - static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty"); + static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0"); + static_assert(Increment <= Remaining, "will go out-of-range"); - static_ford_impl{}(f, Sequence<>{}); + f(Number{}); + static_for_impl{}(f); } }; -template -struct ford_impl +template +struct static_for_impl { - // F signature: F(Array<...> multi_id) - // CurrentMultiIndex: Array<...> - // RemainLengths: Sequence<...> - template - __host__ __device__ void - operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const - { - static_assert(RemainLengths::GetSize() == RemainDim, "wrong!"); - static_assert(RemainDim > 1, "wrong!"); - - constexpr auto next_length = RemainLengths{}.Front(); - - for(index_t i = 0; i < next_length; ++i) - { - ford_impl{}(f, current_multi_id.PushBack(i), RemainLengths{}.PopFront()); - } - } -}; - -template <> -struct ford_impl<1> -{ - // F signature: F(Array<...> multi_id) - // CurrentMultiIndex: Array<...> - // RemainLengths: Sequence<...> - template - __host__ __device__ void - operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const - { - static_assert(RemainLengths::GetSize() == 1, "wrong!"); - - constexpr index_t last_length = RemainLengths{}.Front(); - - for(index_t i = 0; i < last_length; ++i) - { - f(current_multi_id.PushBack(i)); - } - } -}; 
- -// Lengths is Sequence<...> -template -struct ford -{ - // F signature: F(Array<...> multi_id) template - __host__ __device__ void operator()(F f) const + constexpr __host__ __device__ void operator()(F) const { - constexpr index_t first_length = Lengths{}.Front(); - - for(index_t i = 0; i < first_length; ++i) - { - ford_impl{}(f, Array{i}, Lengths{}.PopFront()); - } + // no work left, just return + return; } }; + +// F signature: F(Number) +template +struct static_for +{ + template + constexpr __host__ __device__ void operator()(F f) const + { + static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd"); + + static_assert((NEnd - NBegin) % Increment == 0, + "Wrong! should satisfy (NEnd - NBegin) % Increment == 0"); + +#if 0 + static_if<(NBegin < NEnd)>{}( + [&](auto fwd) { static_for_impl{}(f); }); +#else + static_for_impl{}(f); +#endif + } +}; +#else +template +struct static_for_impl; + +template +struct static_for_impl> +{ + template + __host__ __device__ constexpr void operator()(F f) const + { + swallow{(f(Number{}), 0)...}; + } +}; + +// F signature: F(Number) +template +struct static_for +{ + template + __host__ __device__ constexpr void operator()(F f) const + { + static_assert(NBegin <= NEnd, "wrongs! should have NBegin <= NEnd"); + + static_assert((NEnd - NBegin) % Increment == 0, + "Wrong! 
should satisfy (NEnd - NBegin) % Increment == 0"); + + static_for_impl::SeqType>{}(f); + } +}; +#endif diff --git a/src/include/functional3.hip.hpp b/src/include/functional3.hip.hpp new file mode 100644 index 0000000000..78b95200c5 --- /dev/null +++ b/src/include/functional3.hip.hpp @@ -0,0 +1,109 @@ +#pragma once +#include "functional.hip.hpp" +#include "functional2.hip.hpp" +#include "Sequence.hip.hpp" +#include "Array.hip.hpp" + +// RemainLengths: Sequence<...> +template +struct static_ford_impl +{ + // F signature: F(Sequence<...> multi_id) + // CurrentMultiIndex: Sequence<...> + template + __host__ __device__ void operator()(F f, CurrentMultiIndex) const + { + static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here"); + + static_for<0, RemainLengths::Front(), 1>{}([=](auto I) { + static_ford_impl{}(f, + CurrentMultiIndex::PushBack(I)); + }); + } +}; + +template <> +struct static_ford_impl> +{ + // F signature: F(Sequence<...> multi_id) + // CurrentMultiIndex: Sequence<...> + template + __host__ __device__ void operator()(F f, CurrentMultiIndex) const + { + f(CurrentMultiIndex{}); + } +}; + +// Lengths is Sequence<...> +template +struct static_ford +{ + // F signature: F(Sequence<...> multi_id) + template + __host__ __device__ void operator()(F f) const + { + static_assert(Lengths::GetSize() > 0, "wrong! 
Lengths is empty"); + + static_ford_impl{}(f, Sequence<>{}); + } +}; + +template +struct ford_impl +{ + // F signature: F(Array<...> multi_id) + // CurrentMultiIndex: Array<...> + // RemainLengths: Sequence<...> + template + __host__ __device__ void + operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const + { + static_assert(RemainLengths::GetSize() == RemainDim, "wrong!"); + static_assert(RemainDim > 1, "wrong!"); + + constexpr auto next_length = RemainLengths{}.Front(); + + for(index_t i = 0; i < next_length; ++i) + { + ford_impl{}(f, current_multi_id.PushBack(i), RemainLengths{}.PopFront()); + } + } +}; + +template <> +struct ford_impl<1> +{ + // F signature: F(Array<...> multi_id) + // CurrentMultiIndex: Array<...> + // RemainLengths: Sequence<...> + template + __host__ __device__ void + operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const + { + static_assert(RemainLengths::GetSize() == 1, "wrong!"); + + constexpr index_t last_length = RemainLengths{}.Front(); + + for(index_t i = 0; i < last_length; ++i) + { + f(current_multi_id.PushBack(i)); + } + } +}; + +// Lengths is Sequence<...> +template +struct ford +{ + // F signature: F(Array<...> multi_id) + template + __host__ __device__ void operator()(F f) const + { + constexpr index_t first_length = Lengths{}.Front(); + + for(index_t i = 0; i < first_length; ++i) + { + ford_impl{}(f, Array{i}, Lengths{}.PopFront()); + } + } +}; diff --git a/src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp index 80d933b21a..3cb67d4058 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp @@ -246,7 +246,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw // choose GEMM implementation here 
const auto run_blockwise_gemm = [&](auto... Xs) { -#if 0 +#if 1 return blockwise_gemm.Run(Xs...); #else return blockwise_gemm.Run_asm(Xs...); diff --git a/src/include/threadwise_generic_tensor_slice_op.hip.hpp b/src/include/threadwise_generic_tensor_slice_op.hip.hpp index d40f51b6b2..3803ab23ac 100644 --- a/src/include/threadwise_generic_tensor_slice_op.hip.hpp +++ b/src/include/threadwise_generic_tensor_slice_op.hip.hpp @@ -77,14 +77,14 @@ __device__ void threadwise_generic_tensor_slice_copy_v1( *reinterpret_cast(&p_src[src_index]); }); #else - static_ford{}([&](auto access_multi_id_) { - const auto access_multi_id = sequence2array(access_multi_id_); + static_ford{}([&](auto access_multi_id) { + constexpr index_t itmp = access_multi_id.Back() * DataPerAccess; - auto data_multi_id_in_access_order = access_multi_id; - data_multi_id_in_access_order[nDim - 1] = access_multi_id[nDim - 1] * DataPerAccess; + constexpr auto data_multi_id_in_access_order = + access_multi_id.Modify(Number{}, Number{}); - const auto data_multi_id = - reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{}); + constexpr auto data_multi_id = reorder_array_given_old2new( + sequence2array(data_multi_id_in_access_order), DimAccessOrder{}); const index_t src_index = SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);