From ca42e9101d7eb1930dad87407dcf4d36693ecf65 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 10 Sep 2019 01:53:49 -0500 Subject: [PATCH] adding merge transform --- ...plicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp | 58 ++- .../include/tensor_description/dimension.hpp | 4 +- .../multi_index_transform.hpp | 117 +++-- .../tensor_description/tensor_descriptor.hpp | 196 ++++---- .../tensor_descriptor_helper.hpp | 59 ++- .../include/utility/{Array.hpp => array.hpp} | 7 +- .../include/utility/array_helper.hpp | 6 +- .../include/utility/common_header.hpp | 4 +- .../include/utility/functional.hpp | 2 +- .../include/utility/functional2.hpp | 2 +- .../include/utility/functional3.hpp | 4 +- .../include/utility/functional4.hpp | 4 +- composable_kernel/include/utility/math.hpp | 1 + .../utility/{Sequence.hpp => sequence.hpp} | 430 ++++++++++++------ .../include/utility/sequence_helper.hpp | 4 +- composable_kernel/include/utility/tuple.hpp | 14 +- composable_kernel/include/utility/type.hpp | 4 +- driver/src/driver.cpp | 4 +- 18 files changed, 609 insertions(+), 311 deletions(-) rename composable_kernel/include/utility/{Array.hpp => array.hpp} (99%) rename composable_kernel/include/utility/{Sequence.hpp => sequence.hpp} (52%) diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp index ad53cd89c3..2c5e1e087b 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp @@ -528,34 +528,64 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded #elif 1 // create a native tensor descriptor constexpr auto in_c_h_w_n_global_desc = - make_NativeTensorDescriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides()); + make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides()); constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0); constexpr index_t Hi = in_c_h_w_n_global_desc.GetLength(I1); constexpr index_t Wi = in_c_h_w_n_global_desc.GetLength(I2); constexpr index_t N = in_c_h_w_n_global_desc.GetLength(I3); - constexpr auto pad_h_w = Pad, LowerPads, UpperPads>{}; - constexpr auto pass_c = PassThrough{}; - constexpr auto pass_n = PassThrough{}; + // transformation: {c, h, w, n} --> {n, c, hp, wp} + // {h, w} --> {hp, wp}, {c} --> {c}, {n} --> {n} + constexpr auto in_n_c_hp_wp_global_desc = transform_tensor_descriptor( + in_c_h_w_n_global_desc, + make_tuple( + Pad, LowerPads, UpperPads>{}, PassThrough{}, PassThrough{}), + make_tuple(Sequence<1, 2>{}, Sequence<0>{}, Sequence<3>{}), + make_tuple(Sequence<2, 3>{}, Sequence<1>{}, Sequence<0>{})); - constexpr auto trans = make_tuple(pass_c, pad_h_w, pass_n); - constexpr auto lower_dim_groups = - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}); - constexpr auto upper_dim_groups = - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}); - - constexpr auto in_c_h_w_n_padded_global_desc = transform_tensor_descriptor( - in_c_h_w_n_global_desc, trans, lower_dim_groups, upper_dim_groups); +#if 1 + // transformation: {n, c, hp, wp} --> {c, b} + // {n, hp, wp} --> {b}, {c} --> {c} + constexpr auto in_c_b_global_desc = transform_tensor_descriptor( + in_n_c_hp_wp_global_desc, + make_tuple(Merge{}, + PassThrough{}), + make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); +#endif +#if 1 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) { + // 0 print_tensor_descriptor("in_c_h_w_n_global_desc", in_c_h_w_n_global_desc); - printf("offset: %lu\n", in_c_h_w_n_global_desc.GetOffset({1, 2, 3, 4})); + // 1 + print_tensor_descriptor("in_n_c_hp_wp_global_desc", in_n_c_hp_wp_global_desc); - printf("padded offset: %lu\n", in_c_h_w_n_padded_global_desc.GetOffset({1, 4, 5, 4})); + // 2 + print_tensor_descriptor("in_c_b_global_desc", in_c_b_global_desc); + + constexpr auto idx2 = MultiIndex<2>{1, 4 * (16 * 16) + 5 * 16 + 6}; + auto idx1 = in_c_b_global_desc.CalculateLowerIndex(idx2); + auto idx0 = in_c_b_global_desc.GetLowerTensorDescriptor().CalculateLowerIndex(idx1); + + print_array("idx2: ", idx2); + print_array("idx1: ", idx1); + print_array("idx0: ", idx0); + + printf("in_c_b_global_desc offset: %lu\n", in_c_b_global_desc.CalculateOffset(idx2)); } +#else + { + index_t c = static_cast(threadIdx.x); + index_t h = static_cast(threadIdx.y); + index_t w = static_cast(threadIdx.z); + + p_out_global[0] = in_n_c_h_w_padded_global_desc.CalculateOffset({1, c, h, w}); + } +#endif #endif } #endif diff --git a/composable_kernel/include/tensor_description/dimension.hpp b/composable_kernel/include/tensor_description/dimension.hpp index 1cf8dc4aba..7cb94366a1 100644 --- a/composable_kernel/include/tensor_description/dimension.hpp +++ b/composable_kernel/include/tensor_description/dimension.hpp @@ -18,9 +18,9 @@ struct NativeDimension __host__ __device__ static constexpr auto GetStride() { return Number{}; } - __host__ __device__ static constexpr index_t GetOffset(index_t i) { return i * Stride; } + __host__ __device__ static constexpr index_t CalculateOffset(index_t i) { return i * Stride; } - __host__ __device__ static constexpr index_t GetOffsetDiff(index_t i_diff) + __host__ __device__ static constexpr index_t CalculateOffsetDiff(index_t i_diff) { return i_diff * Stride; } diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp b/composable_kernel/include/tensor_description/multi_index_transform.hpp index c21f055390..a89eb2dfb8 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform.hpp @@ -22,9 +22,12 @@ struct PassThrough __host__ __device__ static constexpr auto GetUpperLengths() { return Sequence{}; } - __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) { return idx_up; } + __host__ __device__ static constexpr auto CalculateLowerIndex(UpperIndex idx_up) + { + return idx_up; + } - __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff) + __host__ __device__ static constexpr auto CalculateLowerIndexDiff(UpperIndex idx_up_diff) { return idx_up_diff; } @@ -36,7 +39,7 @@ struct PassThrough template struct Pad { - static constexpr index_t nDim = LowLengths::GetSize(); + static constexpr index_t nDim = LowLengths::Size(); using LowerIndex = MultiIndex; using UpperIndex = MultiIndex; @@ -52,12 +55,12 @@ struct Pad return GetLowerLengths() + LeftPads{} + RightPads{}; } - __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) + __host__ __device__ static constexpr auto CalculateLowerIndex(UpperIndex idx_up) { return idx_up - LeftPads{}; } - __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff) + __host__ __device__ static constexpr auto CalculateLowerIndexDiff(UpperIndex idx_up_diff) { return idx_up_diff; } @@ -65,21 +68,20 @@ struct Pad __host__ __device__ static constexpr bool IsLinearTransform() { return true; } }; -#if 0 // LowLengths: Sequence<...> template struct Merge { - static constexpr index_t nDimLow = LowLengths::GetSize(); + static constexpr index_t nDimLow = LowLengths::Size(); static constexpr index_t nDimUp = 1; using LowerIndex = MultiIndex; using UpperIndex = MultiIndex; - __host__ __device__ static constexpr auto GetNumOfUpperDimension(){return Number{}}; - __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } + __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number{}; } + __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; } __host__ __device__ static constexpr auto GetUpperLengths() @@ -88,18 +90,56 @@ struct Merge GetLowerLengths(), math::multiplies{}, Number<1>{})>{}; } - __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) + // emulate constexpr lambda + template + struct lambda_CalculateLowerIndex + { + index_t& itmp; + LowerIndex& idx_low; + + __host__ __device__ explicit constexpr lambda_CalculateLowerIndex(index_t& itmp_, + LowerIndex& idx_low_) + : itmp(itmp_), idx_low(idx_low_) + { + } + + template + __host__ __device__ constexpr void operator()(IDim idim) const + { + constexpr index_t stride = PseudoLowStrides::At(idim); + idx_low(idim) = itmp / stride; + itmp -= idx_low[idim] * stride; + } + }; + + __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up) { LowerIndex idx_low; - // not implemeneted + index_t itmp = idx_up[0]; + + constexpr auto pseudo_low_strides = + reverse_inclusive_scan_sequence( + GetLowerLengths().PopFront(), math::multiplies{}, Number<1>{}) + .PushBack(Number<1>{}); + +// calculate index in each of the dimensions in the order of their dimension +#if 1 + static_for<0, nDimLow - 1, 1>{}( + lambda_CalculateLowerIndex(itmp, idx_low)); + + idx_low(nDimLow - 1) = itmp / pseudo_low_strides[nDimLow - 1]; +#else + static_for<0, nDimLow, 1>{}( + lambda_CalculateLowerIndex(itmp, idx_low)); +#endif return idx_low; } // idx_low_diff depends on idx_low_old, so idx_low need to be up-to-date - __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff, - LowerIndex idx_low_old) + __host__ __device__ static constexpr auto CalculateLowerIndexDiff(const UpperIndex& idx_up_diff, + const LowerIndex& idx_low_old) { LowerIndex idx_low_diff; @@ -110,49 +150,48 @@ struct Merge __host__ __device__ static constexpr bool IsLinearTransform() { return false; } }; -#endif // UpLengths: Sequence<...> -template +template struct Unmerge { static constexpr index_t nDimLow = 1; - static constexpr index_t nDimUp = UpLengths::GetSize(); + static constexpr index_t nDimUp = UpLengths::Size(); - using UpperIndex = MultiIndex; using LowerIndex = MultiIndex; - - __host__ __device__ constexpr Unmerge() - { - static_assert(LowLength == accumulate_on_sequence( - UpLengths{}, math::multiplies{}, Number<1>{}), - "wrong! UpLengths need to be "); - } - - __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number{}; } + using UpperIndex = MultiIndex; __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } - __host__ __device__ static constexpr auto GetLowerLengths() { return Sequence{}; } + __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number{}; } + + __host__ __device__ static constexpr auto GetLowerLengths() + { + constexpr index_t low_length = + accumulate_on_sequence(UpLengths{}, math::multiplies{}, Number<1>{}); + + return Sequence{}; + } __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; } - __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) + __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up) { - constexpr auto scans = typename sequence_reverse_inclusive_scan, - 1>::type{}; - LowerIndex idx_low{0}; - static_for<0, nDimUp, 1>{}([&](auto idim) { idx_low(0) += idx_up[idim] * scans[idim]; }); + constexpr auto pseudo_up_strides = + typename sequence_reverse_inclusive_scan, 1>:: + type{}; + + static_for<0, nDimUp, 1>{}( + [&](auto idim) { idx_low(0) += idx_up[idim] * pseudo_up_strides[idim]; }); return idx_low; } - __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff) + __host__ __device__ static constexpr auto CalculateLowerIndexDiff(const UpperIndex& idx_up_diff) { - return GetLowerIndex(idx_up_diff); + return CalculateLowerIndex(idx_up_diff); } __host__ __device__ static constexpr bool IsLinearTransform() { return true; } @@ -165,12 +204,12 @@ template struct Embed { static constexpr index_t nDimLow = 1; - static constexpr index_t nDimUp = UpLengths::GetSize(); + static constexpr index_t nDimUp = UpLengths::Size(); using LowerIndex = MultiIndex; using UpperIndex = MultiIndex; - __host__ __device__ constexpr Embed() + __host__ __device__ explicit constexpr Embed() { static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1, "wrong! # of dimensions not consistent"); @@ -191,7 +230,7 @@ struct Embed __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; } - __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) + __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up) { LowerIndex idx_low(Coefficients{}[nDimUp]); @@ -201,7 +240,7 @@ struct Embed return idx_low; } - __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff) + __host__ __device__ static constexpr auto CalculateLowerIndexDiff(const UpperIndex& idx_up_diff) { LowerIndex idx_low_diff{0}; diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp index 952b378151..8ecc2cdbcf 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor.hpp @@ -18,34 +18,6 @@ struct NativeTensorDescriptor __host__ __device__ static constexpr auto GetNumOfDimension() { return Number{}; } - struct lambda_GetLength - { - template - __host__ __device__ constexpr auto operator()(IDim) const - { - return GetLength(IDim{}); - } - }; - - __host__ __device__ static constexpr auto GetLengths() - { - return typename sequence_gen::type{}; - } - - struct lambda_GetStride - { - template - __host__ __device__ constexpr auto operator()(IDim) const - { - return GetStride(IDim{}); - } - }; - - __host__ __device__ static constexpr auto GetStrides() - { - return typename sequence_gen::type{}; - } - template __host__ __device__ static constexpr auto GetLength(Number) { @@ -58,7 +30,41 @@ struct NativeTensorDescriptor return mDimensions.At(Number{}).GetStride(); } - __host__ __device__ static constexpr index_t GetOffset(const Index& idx) + template + __host__ __device__ static constexpr auto GetLengths(Sequence) + { + return Sequence{})...>{}; + } + + template + __host__ __device__ static constexpr auto GetStrides(Sequence) + { + return Sequence{})...>{}; + } + + template + __host__ __device__ static constexpr auto GetLengths(Number, Number...) + { + return GetLengths(Sequence{}); + } + + template + __host__ __device__ static constexpr auto GetStrides(Number, Number...) + { + return GetStrides(Sequence{}); + } + + __host__ __device__ static constexpr auto GetLengths() + { + return GetLengths(typename arithmetic_sequence_gen<0, nDim, 1>::type{}); + } + + __host__ __device__ static constexpr auto GetStrides() + { + return GetStrides(typename arithmetic_sequence_gen<0, nDim, 1>::type{}); + } + + __host__ __device__ static constexpr index_t CalculateOffset(const Index& idx) { index_t offset = 0; @@ -67,7 +73,7 @@ struct NativeTensorDescriptor return offset; } - __host__ __device__ static constexpr index_t GetOffsetDiff(const Index& idx_diff) + __host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff) { index_t offset_diff = 0; @@ -161,8 +167,10 @@ struct TransformedTensorDescriptor // UpDimensionIds should include all up-dimensions // TODO: sanity check: while a up-dimension could be associated with multille - // transformation, - // a low-dimension should be associated with only one transformation + // transformation, a low-dimension should be associated with only one transformation + + // TODO: sanity-check: GetLowerLengths of each transform should be consistent with lengths + // of lower-tensor-descriptor } __host__ __device__ static constexpr auto GetNumOfDimension() @@ -170,49 +178,78 @@ struct TransformedTensorDescriptor return GetNumOfUpperDimension(); } -#if 0 - __host__ __device__ static constexpr auto GetUpperLengths() - { - struct lambda_get_upper_lengths - { - template - __host__ __device__ constexpr auto operator()(Transform tran) const - { - return tran.GetUpperLengths(); - } - }; - - constexpr auto tuple_of_upper_lengths = - transform_tuple(Transforms, lambda_get_upper_lengths{}); - - constexpr auto all_upper_lengths = merge_tuple_of_sequences(tuple_of_upper_lengths); - - constexpr auto all_upper_dimension_ids = merge_tuple_of_sequences(UpDimensionIds{}); - - // TODO: sanity-check all_upper_dimension_ids contain all upper-dimensions - // TODO: sanity-check all_upper_lengths have no conflicting upper-length - - using sort_dimension_ids = - sequence_unique_sort>; - - constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type; - constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type; - - constexpr auto sorted_upper_lengths = - sequence_element_pick(all_upper_lengths, sorted2unsorted_map); - - return sorted_upper_lengths; - } - - __host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); } -#endif - __host__ __device__ static constexpr auto GetLowerTensorDescriptor() { return LowTensorDescriptor{}; } - __host__ __device__ static constexpr LowerIndex GetLowerIndex(const UpperIndex& idx_up) + __host__ __device__ static constexpr auto GetLowerLengths() + { + return GetLowerTensorDescriptor().GetLengths(); + } + + struct lambda_GetUpperLengths + { + template + __host__ __device__ constexpr auto operator()(const Transform& tran) const + { + return tran.GetUpperLengths(); + } + }; + + __host__ __device__ static constexpr auto GetUpperLengths() + { + constexpr auto tuple_of_up_lengths = + transform_tuple(lambda_GetUpperLengths{}, Transforms{}); + + constexpr auto mingled_up_lengths = unpack(lambda_merge_sequences{}, tuple_of_up_lengths); + + constexpr auto mingled_up_dimension_ids = + unpack(lambda_merge_sequences{}, UpDimensionIds{}); + + // TODO: sanity-check mingled_up_dimension_ids contain all upper-dimensions + // TODO: sanity-check mingled_up_lengths have no conflicting upper-length + + // sort by upper-dimension-ids + using sort_up_dimension_ids = sequence_unique_sort, + math::equal>; + + // sanity-check sorted-upper-dimension-ids should be Sequence<0, 1, ... nDimUp-1> + static_assert(is_same::type>{}, + "wrong! UpDimensionIds is not configured correctly"); + + constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{}; + + constexpr auto sorted_up_lengths = + pick_sequence_elements(mingled_up_lengths, sorted2unsorted_map); + + return sorted_up_lengths; + } + + __host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); } + + template + __host__ __device__ static constexpr auto GetLength(Number) + { + return GetLengths()[IDim]; + } + + template + __host__ __device__ static constexpr auto GetLengths(Sequence) + { + return Sequence{})...>{}; + } + + template + __host__ __device__ static constexpr auto GetLengths(Number, Number...) + { + return GetLengths(Sequence{}); + } + + // TODO: right now return value is constexpr because use of non-constepxr lambda + __host__ __device__ static constexpr LowerIndex CalculateLowerIndex(const UpperIndex& idx_up) { LowerIndex idx_low; @@ -225,14 +262,15 @@ struct TransformedTensorDescriptor // this assume each lower (single) index is only assocaited with one transformation, // which is required for index transformation, and has been checked during constructor // of TransformedTensorDescriptor - idx_low_part = tran.GetLowerIndex(to_array(idx_up_part)); + idx_low_part = tran.CalculateLowerIndex(to_array(idx_up_part)); }); return idx_low; } - __host__ __device__ static constexpr LowerIndex GetLowerIndexDiff(const UpperIndex& idx_up_diff, - const LowerIndex& idx_low_old) + // TODO: right now return value is constexpr because use of non-constepxr lambda + __host__ __device__ static constexpr LowerIndex + CalculateLowerIndexDiff(const UpperIndex& idx_up_diff, const LowerIndex& idx_low_old) { LowerIndex idx_low_diff; @@ -250,15 +288,15 @@ struct TransformedTensorDescriptor // this assume each lower (single) index is associated with only one transformation, // which is required for index transformation, and has been checked during constructor // of TransformedTensorDescriptor - idx_low_diff_part = tran.GetLowerIndex(idx_up_diff_part, idx_low_old_part); + idx_low_diff_part = tran.CalculateLowerIndex(idx_up_diff_part, idx_low_old_part); }); return idx_low_diff; } - __host__ __device__ static constexpr index_t GetOffset(const UpperIndex& idx_up) + __host__ __device__ static constexpr index_t CalculateOffset(const UpperIndex& idx_up) { - return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up)); + return GetLowerTensorDescriptor().CalculateOffset(CalculateLowerIndex(idx_up)); } #if 0 @@ -286,14 +324,14 @@ struct TransformedTensorDescriptor }; template -__host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence, - Sequence) +__host__ __device__ constexpr auto make_native_tensor_descriptor(Sequence, + Sequence) { return NativeTensorDescriptor...>{}; } template -__host__ __device__ constexpr auto make_NativeTensorDescriptor_packed(Lengths) +__host__ __device__ constexpr auto make_native_tensor_descriptor_packed(Lengths) { constexpr auto strides = reverse_inclusive_scan_sequence( Lengths::PopFront(), math::multiplies{}, Number<1>{}) diff --git a/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp b/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp index 04009a740c..603cd1e3d6 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp @@ -7,12 +7,19 @@ namespace ck { template -__host__ __device__ void print_tensor_descriptor(const char* s, - NativeTensorDescriptor desc) +__host__ __device__ void +print_tensor_descriptor(const char* s, const NativeTensorDescriptor& desc) { print_tensor_descriptor_impl(s, desc.GetLengths(), desc.GetStrides()); } +template +__host__ __device__ void print_tensor_descriptor(const char* s, + const TransformedTensorDescriptor& desc) +{ + print_tensor_descriptor_impl(s, desc.GetLengths()); +} + template __host__ __device__ void print_tensor_descriptor_impl(const char* s, Sequence, Sequence) @@ -113,5 +120,53 @@ print_tensor_descriptor_impl(const char* s, Sequence, Sequence +__host__ __device__ void print_tensor_descriptor_impl(const char* s, Sequence) +{ + constexpr index_t nDim = sizeof...(Lengths); + + static_assert(nDim > 0 && nDim <= 12, "wrong!"); + + static_if{}([&](auto) { printf("%s dim %u, lengths {%u}\n", s, nDim, Lengths...); }); + + static_if{}( + [&](auto) { printf("%s dim %u, lengths {%u %u}\n", s, nDim, Lengths...); }); + + static_if{}( + [&](auto) { printf("%s dim %u, lengths {%u %u %u}\n", s, nDim, Lengths...); }); + + static_if{}( + [&](auto) { printf("%s dim %u, lengths {%u %u %u %u}\n", s, nDim, Lengths...); }); + + static_if{}( + [&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u}\n", s, nDim, Lengths...); }); + + static_if{}( + [&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u}, \n", s, nDim, Lengths...); }); + + static_if{}( + [&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u %u}\n", s, nDim, Lengths...); }); + + static_if{}([&](auto) { + printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...); + }); + + static_if{}([&](auto) { + printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...); + }); + + static_if{}([&](auto) { + printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...); + }); + + static_if{}([&](auto) { + printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...); + }); + + static_if{}([&](auto) { + printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...); + }); +} + } // namespace ck #endif diff --git a/composable_kernel/include/utility/Array.hpp b/composable_kernel/include/utility/array.hpp similarity index 99% rename from composable_kernel/include/utility/Array.hpp rename to composable_kernel/include/utility/array.hpp index 1cc8d4d0d6..52e92da3f1 100644 --- a/composable_kernel/include/utility/Array.hpp +++ b/composable_kernel/include/utility/array.hpp @@ -1,7 +1,7 @@ #ifndef CK_ARRAY_HPP #define CK_ARRAY_HPP -#include "Sequence.hpp" +#include "sequence.hpp" #include "functional2.hpp" namespace ck { @@ -17,7 +17,7 @@ struct Array __host__ __device__ explicit constexpr Array() {} template - __host__ __device__ explicit constexpr Array(X x, Xs... xs) + __host__ __device__ constexpr Array(X x, Xs... xs) : mData{static_cast(x), static_cast(xs)...} { static_assert(sizeof...(Xs) + 1 == NSize, "wrong! size"); @@ -176,7 +176,6 @@ __host__ __device__ constexpr auto pick_array_element(Arr& a, Picks) return ArrayElementPicker(a); } -#if 1 template __host__ __device__ constexpr auto to_array(const T& x) { @@ -186,8 +185,8 @@ __host__ __device__ constexpr auto to_array(const T& x) return y; } -#endif +// TODO: remove this template __host__ __device__ constexpr auto sequence2array(Sequence) { diff --git a/composable_kernel/include/utility/array_helper.hpp b/composable_kernel/include/utility/array_helper.hpp index 55f022be7d..a3536309fa 100644 --- a/composable_kernel/include/utility/array_helper.hpp +++ b/composable_kernel/include/utility/array_helper.hpp @@ -1,12 +1,12 @@ #ifndef CK_ARRAY_HELPER_HPP #define CK_ARRAY_HELPER_HPP -#include "Array.hpp" +#include "array.hpp" namespace ck { template -__host__ __device__ void print_Array(const char* s, Array a) +__host__ __device__ void print_array(const char* s, Array a) { constexpr index_t nsize = a.GetSize(); @@ -90,4 +90,4 @@ __host__ __device__ void print_Array(const char* s, Array a) } } // namespace ck -#endif \ No newline at end of file +#endif diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp index 902e78f25c..ed581c95c0 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/composable_kernel/include/utility/common_header.hpp @@ -9,9 +9,9 @@ #include "tuple.hpp" #include "math.hpp" #include "vector_type.hpp" -#include "Sequence.hpp" +#include "sequence.hpp" #include "sequence_helper.hpp" -#include "Array.hpp" +#include "array.hpp" #include "array_helper.hpp" #include "functional.hpp" #include "functional2.hpp" diff --git a/composable_kernel/include/utility/functional.hpp b/composable_kernel/include/utility/functional.hpp index a22f8b2435..3dd469c8bc 100644 --- a/composable_kernel/include/utility/functional.hpp +++ b/composable_kernel/include/utility/functional.hpp @@ -2,7 +2,7 @@ #define CK_FUNCTIONAL_HPP #include "integral_constant.hpp" -#include "Sequence.hpp" +#include "sequence.hpp" #include "type.hpp" namespace ck { diff --git a/composable_kernel/include/utility/functional2.hpp b/composable_kernel/include/utility/functional2.hpp index df10ca1f25..68706a2973 100644 --- a/composable_kernel/include/utility/functional2.hpp +++ b/composable_kernel/include/utility/functional2.hpp @@ -2,7 +2,7 @@ #define CK_FUNCTIONAL2_HPP #include "functional.hpp" -#include "Sequence.hpp" +#include "sequence.hpp" namespace ck { diff --git a/composable_kernel/include/utility/functional3.hpp b/composable_kernel/include/utility/functional3.hpp index 540938fd10..48a0933793 100644 --- a/composable_kernel/include/utility/functional3.hpp +++ b/composable_kernel/include/utility/functional3.hpp @@ -3,8 +3,8 @@ #include "functional.hpp" #include "functional2.hpp" -#include "Sequence.hpp" -#include "Array.hpp" +#include "sequence.hpp" +#include "array.hpp" namespace ck { diff --git a/composable_kernel/include/utility/functional4.hpp b/composable_kernel/include/utility/functional4.hpp index 2cbc94ea8b..70475ced4a 100644 --- a/composable_kernel/include/utility/functional4.hpp +++ b/composable_kernel/include/utility/functional4.hpp @@ -1,9 +1,9 @@ #ifndef CK_FUNCTIONAL4_HPP #define CK_FUNCTIONAL4_HPP -#include "Sequence.hpp" +#include "sequence.hpp" #include "tuple.hpp" -#include "Array.hpp" +#include "array.hpp" namespace ck { diff --git a/composable_kernel/include/utility/math.hpp b/composable_kernel/include/utility/math.hpp index 7d7252cd4d..ba70e7ab26 100644 --- a/composable_kernel/include/utility/math.hpp +++ b/composable_kernel/include/utility/math.hpp @@ -3,6 +3,7 @@ #include "config.hpp" #include "integral_constant.hpp" +#include "type.hpp" namespace ck { namespace math { diff --git a/composable_kernel/include/utility/Sequence.hpp b/composable_kernel/include/utility/sequence.hpp similarity index 52% rename from composable_kernel/include/utility/Sequence.hpp rename to composable_kernel/include/utility/sequence.hpp index b9327bdd81..8a9fff5979 100644 --- a/composable_kernel/include/utility/Sequence.hpp +++ b/composable_kernel/include/utility/sequence.hpp @@ -2,7 +2,9 @@ #define CK_SEQUENCE_HPP #include "integral_constant.hpp" +#include "type.hpp" #include "functional.hpp" +#include "math.hpp" namespace ck { @@ -155,8 +157,8 @@ struct Sequence static_assert(I < Size(), "wrong!"); using seq_split = sequence_split; - constexpr auto seq_left = typename seq_split::SeqType0{}; - constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront(); + constexpr auto seq_left = typename seq_split::left_type{}; + constexpr auto seq_right = typename seq_split::right_type{}.PopFront(); return seq_left.PushBack(Number{}).PushBack(seq_right); } @@ -188,34 +190,34 @@ struct sequence_merge }; // generate sequence -template -struct sequence_gen_impl -{ - static constexpr index_t NRemainLeft = NRemain / 2; - static constexpr index_t NRemainRight = NRemain - NRemainLeft; - static constexpr index_t IMiddle = IBegin + NRemainLeft; - - using type = - typename sequence_merge::type, - typename sequence_gen_impl::type>::type; -}; - -template -struct sequence_gen_impl -{ - static constexpr index_t Is = F{}(Number{}); - using type = Sequence; -}; - -template -struct sequence_gen_impl -{ - using type = Sequence<>; -}; - template struct sequence_gen { + template + struct sequence_gen_impl + { + static constexpr index_t NRemainLeft = NRemain / 2; + static constexpr index_t NRemainRight = NRemain - NRemainLeft; + static constexpr index_t IMiddle = IBegin + NRemainLeft; + + using type = typename sequence_merge< + typename sequence_gen_impl::type, + typename sequence_gen_impl::type>::type; + }; + + template + struct sequence_gen_impl + { + static constexpr index_t Is = G{}(Number{}); + using type = Sequence; + }; + + template + struct sequence_gen_impl + { + using type = Sequence<>; + }; + using type = typename sequence_gen_impl<0, NSize, F>::type; }; @@ -281,8 +283,8 @@ struct sequence_split using range0 = typename arithmetic_sequence_gen<0, I, 1>::type; using range1 = typename arithmetic_sequence_gen::type; - using SeqType0 = decltype(Seq::Extract(range0{})); - using SeqType1 = decltype(Seq::Extract(range1{})); + using left_type = decltype(Seq::Extract(range0{})); + using right_type = decltype(Seq::Extract(range1{})); }; // reverse sequence @@ -293,8 +295,8 @@ struct sequence_reverse using seq_split = sequence_split; using type = typename sequence_merge< - typename sequence_reverse::type, - typename sequence_reverse::type>::type; + typename sequence_reverse::type, + typename sequence_reverse::type>::type; }; template @@ -309,138 +311,264 @@ struct sequence_reverse> using type = Sequence; }; -template -struct sequence_sort +template +struct sequence_sort_impl { - template + template struct sorted_sequence_merge_impl { - static constexpr bool pick_left = SeqLeft::Front() < SeqRight::Front(); - static constexpr index_t next_value = pick_left ? SeqLeft::Front() : SeqRight::Front(); + static constexpr bool choose_left = LeftValues::Front() < RightValues::Front(); - using new_merged_seq = decltype(MergedSeq::PushBack(Number{})); + static constexpr index_t chosen_value = + choose_left ? LeftValues::Front() : RightValues::Front(); + static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front(); - using new_left_seq = - typename conditional::type; - using new_right_seq = - typename conditional::type; + using new_merged_values = decltype(MergedValues::PushBack(Number{})); + using new_merged_ids = decltype(MergedIds::PushBack(Number{})); + + using new_left_values = + typename conditional::type; + using new_left_ids = + typename conditional::type; + + using new_right_values = + typename conditional::type; + using new_right_ids = + typename conditional::type; + + using merge = sorted_sequence_merge_impl; + // this is output + using merged_values = typename merge::merged_values; + using merged_ids = typename merge::merged_ids; + }; + + template + struct sorted_sequence_merge_impl, + Sequence<>, + MergedValues, + MergedIds, + Comp> + { + using merged_values = typename sequence_merge::type; + using merged_ids = typename sequence_merge::type; + }; + + template + struct sorted_sequence_merge_impl, + Sequence<>, + RightValues, + RightIds, + MergedValues, + MergedIds, + Comp> + { + using merged_values = typename sequence_merge::type; + using merged_ids = typename sequence_merge::type; + }; + + template + struct sorted_sequence_merge + { + using merge = sorted_sequence_merge_impl, + Sequence<>, + Comp>; + + using merged_values = typename merge::merged_values; + using merged_ids = typename merge::merged_ids; + }; + + static constexpr index_t nsize = Values::Size(); + + using split_unsorted_values = sequence_split; + using split_unsorted_ids = sequence_split; + + using left_unsorted_values = typename split_unsorted_values::left_type; + using left_unsorted_ids = typename split_unsorted_ids::left_type; + using left_sort = sequence_sort_impl; + using left_sorted_values = typename left_sort::sorted_values; + using left_sorted_ids = typename left_sort::sorted_ids; + + using right_unsorted_values = typename split_unsorted_values::right_type; + using right_unsorted_ids = typename split_unsorted_ids::right_type; + using right_sort = sequence_sort_impl; + using right_sorted_values = typename right_sort::sorted_values; + using right_sorted_ids = typename right_sort::sorted_ids; + + using merged_sorted = sorted_sequence_merge; + + using sorted_values = typename merged_sorted::merged_values; + using sorted_ids = typename merged_sorted::merged_ids; +}; + +template +struct sequence_sort_impl, Sequence, Compare> +{ + static constexpr bool choose_x = Compare{}(ValueX, ValueY); + + using sorted_values = + typename conditional, Sequence>::type; + using sorted_ids = typename conditional, Sequence>::type; +}; + +template +struct sequence_sort_impl, Sequence, Compare> +{ + using sorted_values = Sequence; + using sorted_ids = Sequence; +}; + +template +struct sequence_sort +{ + using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type; + using sort = sequence_sort_impl; + + // this is output + using type = typename sort::sorted_values; + using sorted2unsorted_map = typename sort::sorted_ids; +}; + +template +struct sequence_unique_sort +{ + template + struct sorted_sequence_uniquify_impl + { + static constexpr index_t current_value = RemainValues::Front(); + static constexpr index_t current_id = RemainIds::Front(); + + static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back()); + + using new_remain_values = decltype(RemainValues::PopFront()); + using new_remain_ids = decltype(RemainIds::PopFront()); + + using new_uniquified_values = + typename conditional{})), + UniquifiedValues>::type; + + using new_uniquified_ids = + typename conditional{})), + UniquifiedIds>::type; + + using uniquify = sorted_sequence_uniquify_impl; + + // this is output + using uniquified_values = typename uniquify::uniquified_values; + using uniquified_ids = typename uniquify::uniquified_ids; + }; + + template + struct sorted_sequence_uniquify_impl, + Sequence<>, + UniquifiedValues, + UniquifiedIds, + Eq> + { + using uniquified_values = UniquifiedValues; + using uniquified_ids = UniquifiedIds; + }; + + template + struct sorted_sequence_uniquify + { + using uniquify = sorted_sequence_uniquify_impl, + Sequence, + Eq>; + + using uniquified_values = typename uniquify::uniquified_values; + using uniquified_ids = typename uniquify::uniquified_ids; + }; + + using sort = sequence_sort; + using sorted_values = typename sort::type; + using sorted_ids = typename sort::sorted2unsorted_map; + + using uniquify = sorted_sequence_uniquify; + + // this is output + using type = typename uniquify::uniquified_values; + using sorted2unsorted_map = typename uniquify::uniquified_ids; +}; + +template +struct is_valid_sequence_map +{ + static constexpr bool value = + is_same::type, + typename sequence_sort>::type>{}; +}; + +template +struct sequence_map_inverse +{ + template + struct sequence_map_inverse_impl + { + static constexpr auto new_y2x = + WorkingY2X::Modify(X2Y::At(Number{}), Number{}); using type = - typename sorted_sequence_merge_impl:: + typename sequence_map_inverse_impl:: type; }; - template - struct sorted_sequence_merge_impl, MergedSeq, Comp> + template + struct sequence_map_inverse_impl { - using type = typename sequence_merge::type; + using type = WorkingY2X; }; - template - struct sorted_sequence_merge_impl, SeqRight, MergedSeq, Comp> - { - using type = typename sequence_merge::type; - }; - - template - struct sorted_sequence_merge - { - using type = typename sorted_sequence_merge_impl, Comp>::type; - }; - - using split = sequence_split; - using unsorted_left = typename split::SeqType0; - using unsorted_right = typename split::SeqType1; - - using sorted_left = typename sequence_sort::type; - using sorted_right = typename sequence_sort::type; - - using type = typename sorted_sequence_merge::type; -}; - -template -struct sequence_sort, Compare> -{ - static constexpr bool x_first = Compare{}(X, Y); - - using type = typename conditional, Sequence>::type; -}; - -template -struct sequence_sort, Compare> -{ - using type = Sequence; -}; - -template -struct sequence_unique_sort -{ - template - struct sorted_sequence_uniquify_impl - { - static constexpr index_t new_value = WorkInputSeq::Front(); - using new_work_input_seq = decltype(WorkInputSeq::PopFront()); - - using new_working_output_seq = - typename conditional{}))>::type; - }; - - template - struct sorted_sequence_uniquify_impl, Eq> - { - using type = WorkInputSeq; - }; - - template - struct sorted_sequence_uniquify - { - using type = typename sorted_sequence_uniquify_impl, Eq>::type; - }; - - using sorted_seq = typename sequence_sort::type; - - using type = typename sorted_sequence_uniquify::type; -}; - -template -struct is_valid_sequence_map -{ - // not implemented yet, always return true - static constexpr integral_constant value = integral_constant{}; - - // TODO: add proper check for is_valid, something like: - // static constexpr bool value = - // is_same::type, - // typename sequence_sort::SortedSeqType>{}; -}; - -template -struct sequence_map_inverse_impl -{ - private: - static constexpr auto new_y2x = WorkingY2X::Modify(X2Y::At(Number{}), Number{}); - - public: using type = - typename sequence_map_inverse_impl::type; -}; - -template -struct sequence_map_inverse_impl -{ - using type = WorkingY2X; -}; - -template -struct sequence_map_inverse -{ - using type = - typename sequence_map_inverse_impl::type, + typename sequence_map_inverse_impl::type, 0, - X2Y::Size()>::type; + SeqMap::Size()>::type; }; template @@ -601,6 +729,12 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number{}).Reverse(); } +template +__host__ __device__ constexpr auto pick_sequence_elements(Seq, Sequence) +{ + return Sequence{})...>{}; +} + template struct lambda_accumulate_on_sequence { diff --git a/composable_kernel/include/utility/sequence_helper.hpp b/composable_kernel/include/utility/sequence_helper.hpp index c499d6e9a2..71abfea1fe 100644 --- a/composable_kernel/include/utility/sequence_helper.hpp +++ b/composable_kernel/include/utility/sequence_helper.hpp @@ -1,12 +1,12 @@ #ifndef CK_SEQUENCE_HELPER_HPP #define CK_SEQUENCE_HELPER_HPP -#include "Sequence.hpp" +#include "sequence.hpp" namespace ck { template -__host__ __device__ void print_Sequence(const char* s, Sequence) +__host__ __device__ void print_sequence(const char* s, Sequence) { constexpr index_t nsize = Sequence::Size(); diff --git a/composable_kernel/include/utility/tuple.hpp b/composable_kernel/include/utility/tuple.hpp index 3fa9d0fccd..c26cad2ae6 100644 --- a/composable_kernel/include/utility/tuple.hpp +++ b/composable_kernel/include/utility/tuple.hpp @@ -3,7 +3,7 @@ #include "integral_constant.hpp" #include "type.hpp" -#include "Sequence.hpp" +#include "sequence.hpp" namespace ck { @@ -114,19 +114,19 @@ __host__ __device__ constexpr auto make_tuple(Xs&&... xs) namespace detail { -template -__host__ __device__ constexpr auto transpose_tuple_impl(X& x, F f, Sequence) +template +__host__ __device__ constexpr auto transform_tuple_impl(F f, const X& x, Sequence) { return make_tuple(f(x.At(Number{}))...); } } // namespace detail -template -__host__ __device__ constexpr auto transpose_tuple(X& x, F f) +template +__host__ __device__ constexpr auto transform_tuple(F f, const X& x) { - return detail::transpose_tuple_impl( - x, f, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); + return detail::transform_tuple_impl( + f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); } } // namespace ck diff --git a/composable_kernel/include/utility/type.hpp b/composable_kernel/include/utility/type.hpp index 98a4640027..ac6d306d7f 100644 --- a/composable_kernel/include/utility/type.hpp +++ b/composable_kernel/include/utility/type.hpp @@ -2,10 +2,12 @@ #define CK_TYPE_HPP #include "integral_constant.hpp" -#include "Sequence.hpp" namespace ck { +template +struct Sequence; + template struct is_same : public integral_constant { diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp index e25426d812..b20fc26f78 100644 --- a/driver/src/driver.cpp +++ b/driver/src/driver.cpp @@ -84,8 +84,8 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - constexpr index_t HPad = 2; - constexpr index_t WPad = 2; + constexpr index_t HPad = 3; + constexpr index_t WPad = 3; #elif 1 // 3x3, 34x34 constexpr index_t N = 64;